Commit 73c6377 (0 parents)
Initial commit with all files including LFS
Files changed:
- .dockerignore +15 -0
- .env.example +8 -0
- .gitattributes +4 -0
- .gitignore +212 -0
- API_DOCUMENTATION.md +413 -0
- Dockerfile +50 -0
- README.md +237 -0
- api/__init__.py +0 -0
- api/app.py +118 -0
- api/exceptions.py +86 -0
- api/middleware.py +177 -0
- api/models.py +188 -0
- api/routers/__init__.py +0 -0
- api/routers/auth.py +201 -0
- api/routers/hbv_assessment.py +116 -0
- api/routers/health.py +228 -0
- api/routers/medical.py +294 -0
- api/tempCodeRunnerFile.py +1 -0
- app.py +23 -0
- core/__init__.py +0 -0
- core/agent.py +993 -0
- core/background_init.py +153 -0
- core/config.py +152 -0
- core/context_enrichment.py +342 -0
- core/data_loaders.py +142 -0
- core/github_storage.py +462 -0
- core/hbv_assessment.py +298 -0
- core/medical_terminology.py +288 -0
- core/retrievers.py +200 -0
- core/text_parser.py +194 -0
- core/text_processors.py +20 -0
- core/tools.py +296 -0
- core/tracing.py +104 -0
- core/utils.py +370 -0
- core/validation.py +639 -0
- core/vector_store.py +34 -0
- data/chunks.pkl +3 -0
- data/vector_store/index.faiss +3 -0
- data/vector_store/index.pkl +3 -0
- example_patient_input.json +171 -0
- export_prompts.py +214 -0
- requirements.txt +41 -0
- tempCodeRunnerFile.python +146 -0
.dockerignore
ADDED
@@ -0,0 +1,15 @@
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+*.sqlite3
+.env
+.env.*
+.git/
+.gitignore
+.vscode/
+.idea/
+logs/
+data/new_data/
+**/__pycache__/
+**/*.py[cod]
.env.example
ADDED
@@ -0,0 +1,8 @@
+# OpenAI API Configuration (Optional - only needed for AI chat feature)
+OPENAI_API_KEY=your_openai_api_key_here
+
+# LangSmith Configuration (Optional - for tracing)
+LANGSMITH_API_KEY=your_langsmith_api_key_here
+LANGSMITH_PROJECT=hbv-ai-assistant
+LANGCHAIN_PROJECT=hbv-ai-assistant
+LANGSMITH_URL=https://api.smith.langchain.com
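
A minimal sketch of how these variables might be consumed at startup, assuming `python-dotenv` is available; the repo's actual configuration loading lives in core/config.py and may well differ:

```python
# Illustrative only: assumes the python-dotenv package; not taken from the repo.
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory, if present

openai_key = os.getenv("OPENAI_API_KEY")        # optional: AI chat feature
langsmith_key = os.getenv("LANGSMITH_API_KEY")  # optional: tracing

if openai_key is None:
    print("OPENAI_API_KEY not set; AI chat endpoints will be unavailable.")
```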
.gitattributes
ADDED
@@ -0,0 +1,4 @@
+*.faiss filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,212 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[codz]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py.cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+#poetry.toml
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+#pdm.lock
+#pdm.toml
+.pdm-python
+.pdm-build/
+
+# pixi
+# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+#pixi.lock
+# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+# in the .venv directory. It is recommended not to include this directory in version control.
+.pixi
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.envrc
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Abstra
+# Abstra is an AI-powered process automation framework.
+# Ignore directories containing user credentials, local state, and settings.
+# Learn more at https://abstra.io/docs
+.abstra/
+
+# Visual Studio Code
+# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+# and can be added to the global gitignore or merged into this file. However, if you prefer,
+# you could uncomment the following to ignore the entire vscode folder
+# .vscode/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# Cursor
+# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+# refer to https://docs.cursor.com/context/ignore-files
+.cursorignore
+.cursorindexingignore
+
+# Marimo
+marimo/_static/
+marimo/_lsp/
+__marimo__/
+
+*.ipynb
+*.docx
+*.pptx
API_DOCUMENTATION.md
ADDED
@@ -0,0 +1,413 @@
+# HBV AI Assistant - API Documentation for Frontend Engineers
+
+## Table of Contents
+- [Overview](#overview)
+- [Base URL](#base-url)
+- [Core Endpoints](#core-endpoints)
+  - [1. POST /assess](#1-post-assess)
+  - [2. POST /assess/text](#2-post-assesstext)
+  - [3. POST /ask](#3-post-ask)
+  - [4. POST /ask/stream](#4-post-askstream)
+  - [5. GET /health](#5-get-health)
+- [Additional Endpoints](#additional-endpoints)
+- [Error Handling](#error-handling)
+- [Code Examples](#code-examples)
+
+---
+
+## Overview
+
+The HBV AI Assistant API provides endpoints for evaluating patient eligibility for HBV (Hepatitis B Virus) treatment according to SASLT 2021 guidelines. The API supports both structured data input and free-form text parsing, along with an interactive AI chat for guideline exploration.
+
+**API Version:** 1.0.0
+
+---
+
+## Base URL
+
+```
+Development: http://127.0.0.1:8000
+Production: [Your production URL]
+```
+
+---
+
+## Core Endpoints
+
+### 1. POST /assess
+
+**Primary endpoint for HBV patient eligibility assessment with structured data input.**
+
+#### Request
+
+**URL:** `/assess`
+**Method:** `POST`
+**Content-Type:** `application/json`
+
+**Request Body Schema:**
+
+```json
+{
+  "sex": "Male",                           // Required: "Male" or "Female"
+  "age": 45,                               // Required: 0-120
+  "pregnancy_status": "Not pregnant",      // Required: "Not pregnant" or "Pregnant"
+  "hbsag_status": "Positive",              // Required: "Positive" or "Negative"
+  "duration_hbsag_months": 12,             // Required: Duration in months (≥0)
+  "hbv_dna_level": 5000,                   // Required: HBV DNA in IU/mL (≥0)
+  "hbeag_status": "Positive",              // Required: "Positive" or "Negative"
+  "alt_level": 80,                         // Required: ALT in U/L (≥0)
+  "fibrosis_stage": "F2-F3",               // Required: "F0-F1", "F2-F3", or "F4"
+  "necroinflammatory_activity": "A2",      // Required: "A0", "A1", "A2", or "A3"
+  "extrahepatic_manifestations": false,    // Required: true or false
+  "immunosuppression_status": "None",      // Optional: "None", "Chemotherapy", "Other"
+  "coinfections": [],                      // Optional: Array of "HIV", "HCV", "HDV"
+  "family_history_cirrhosis_hcc": false,   // Required: true or false
+  "other_comorbidities": []                // Optional: Array of strings
+}
+```
+
+#### Response
+
+**Status Code:** `200 OK`
+
+**Response Body:**
+
+```json
+{
+  "eligible": true,
+  "recommendations": "Based on SASLT 2021 guidelines, this patient IS ELIGIBLE for HBV antiviral treatment.\n\n**Eligibility Criteria Met:**\n- HBeAg-positive chronic hepatitis B with HBV DNA >2,000 IU/mL and ALT elevation (80 U/L) [SASLT 2021, Page 15]\n- Significant fibrosis (F2-F3) with moderate necroinflammatory activity (A2) [SASLT 2021, Page 18]\n\n**Treatment Recommendations:**\n1. **First-line therapy:** Entecavir (ETV) 0.5mg daily or Tenofovir Disoproxil Fumarate (TDF) 300mg daily [SASLT 2021, Page 22]\n2. **Alternative:** Tenofovir Alafenamide (TAF) 25mg daily for patients with renal concerns [SASLT 2021, Page 23]\n\n**Monitoring:**\n- HBV DNA every 3 months during first year\n- ALT and complete blood count every 3-6 months\n- HBeAg/anti-HBe every 6-12 months [SASLT 2021, Page 28]"
+}
+```
+
+#### Field Descriptions
+
+| Field | Type | Description | Validation |
+|-------|------|-------------|------------|
+| `sex` | string | Patient's biological sex | "Male" or "Female" |
+| `age` | integer | Patient's age in years | 0-120 |
+| `pregnancy_status` | string | Current pregnancy status | "Not pregnant" or "Pregnant" |
+| `hbsag_status` | string | Hepatitis B surface antigen status | "Positive" or "Negative" |
+| `duration_hbsag_months` | integer | How long HBsAg has been positive | ≥0 |
+| `hbv_dna_level` | float | HBV DNA viral load | ≥0 IU/mL |
+| `hbeag_status` | string | Hepatitis B e-antigen status | "Positive" or "Negative" |
+| `alt_level` | float | Alanine aminotransferase level | ≥0 U/L |
+| `fibrosis_stage` | string | Liver fibrosis/cirrhosis stage | "F0-F1", "F2-F3", or "F4" |
+| `necroinflammatory_activity` | string | Degree of liver inflammation | "A0", "A1", "A2", or "A3" |
+| `extrahepatic_manifestations` | boolean | Presence of HBV-related conditions outside liver | true/false |
+| `immunosuppression_status` | string | Current immunosuppression | "None", "Chemotherapy", "Other" |
+| `coinfections` | array | Other viral infections | ["HIV", "HCV", "HDV"] or [] |
+| `family_history_cirrhosis_hcc` | boolean | First-degree relative with cirrhosis/HCC | true/false |
+| `other_comorbidities` | array | Additional medical conditions | Array of strings or [] |
+
+---
+
+### 2. POST /assess/text
+
+**Text-based HBV patient eligibility assessment - parses free-form clinical text.**
+
+#### Request
+
+**URL:** `/assess/text`
+**Method:** `POST`
+**Content-Type:** `application/json`
+
+**Request Body Schema:**
+
+```json
+{
+  "text_input": "45-year-old male patient\nHBsAg: Positive for 12 months\nHBV DNA: 5000 IU/mL\nHBeAg: Positive\nALT: 80 U/L\nFibrosis stage: F2-F3\nNecroinflammatory activity: A2\nNo extrahepatic manifestations\nNo immunosuppression\nNo coinfections (HIV, HCV, HDV)\nNo family history of cirrhosis or HCC"
+}
+```
+
+**Field:**
+- `text_input` (string, required): Free-form text containing patient clinical data. Minimum 10 characters.
+
+#### Response
+
+**Status Code:** `200 OK`
+
+**Response Body:** Same as `/assess` endpoint
+
+```json
+{
+  "eligible": true,
+  "recommendations": "Based on SASLT 2021 guidelines, this patient IS ELIGIBLE for HBV antiviral treatment..."
+}
+```
+
+#### How It Works
+
+1. **LLM-based parsing:** The API uses an AI model to extract structured patient data from the free-form text
+2. **Validation:** Extracted data is validated against the same schema as `/assess`
+3. **Assessment:** Applies the same assessment logic as the structured endpoint
+4. **Response:** Returns the same structured response format
+
+#### Text Input Tips
+
+- Include all relevant clinical parameters
+- Use clear labels (e.g., "HBV DNA:", "ALT:", "Age:")
+- Either natural language or a structured format works
+- The AI will extract and interpret the data intelligently
+
+---
+
+### 3. POST /ask
+
+**Interactive AI chat for exploring HBV treatment guidelines.**
+
+#### Request
+
+**URL:** `/ask`
+**Method:** `POST`
+**Content-Type:** `application/json`
+
+**Request Body Schema:**
+
+```json
+{
+  "query": "What are the first-line treatment options for HBeAg-positive chronic hepatitis B?",
+  "session_id": "doctor_123_session_1",
+  "patient_context": {
+    // Optional: Include HBVPatientInput object for context
+    "age": 45,
+    "sex": "Male",
+    "hbv_dna_level": 5000,
+    // ... other patient fields
+  },
+  "assessment_result": {
+    // Optional: Include prior assessment result for context
+    "eligible": true,
+    "recommendations": "..."
+  }
+}
+```
+
+**Fields:**
+- `query` (string, required): Doctor's question about HBV guidelines. Max 2000 characters.
+- `session_id` (string, optional): Session identifier for conversation continuity. Default: "default"
+- `patient_context` (object, optional): Patient data from a prior assessment to provide context
+- `assessment_result` (object, optional): Assessment result to provide context
+
+#### Response
+
+**Status Code:** `200 OK`
+
+**Response Body:**
+
+```json
+{
+  "response": "According to SASLT 2021 guidelines, the first-line treatment options for HBeAg-positive chronic hepatitis B are:\n\n1. **Entecavir (ETV)** - 0.5mg once daily [SASLT 2021, Page 22]\n2. **Tenofovir Disoproxil Fumarate (TDF)** - 300mg once daily [SASLT 2021, Page 22]\n\nBoth medications have:\n- High barrier to resistance\n- Excellent efficacy in viral suppression\n- Well-established safety profiles\n\n**Alternative option:**\n- **Tenofovir Alafenamide (TAF)** - 25mg once daily, particularly for patients with renal concerns or bone disease [SASLT 2021, Page 23]",
+  "session_id": "doctor_123_session_1"
+}
+```
+
+#### Session Management
+
+- **Conversation continuity:** Use the same `session_id` to maintain conversation context across multiple questions
+- **New conversation:** Use a different `session_id` or omit it for a fresh conversation
+- **Clear session:** Use `DELETE /session/{session_id}` to clear conversation history
+
+---
+
+### 4. POST /ask/stream
+
+**Streaming version of the AI chat endpoint for real-time responses.**
+
+#### Request
+
+**URL:** `/ask/stream`
+**Method:** `POST`
+**Content-Type:** `application/json`
+
+**Request Body:** Same as `/ask` endpoint
+
+```json
+{
+  "query": "What monitoring is required during HBV treatment?",
+  "session_id": "doctor_123_session_1"
+}
+```
+
+#### Response
+
+**Status Code:** `200 OK`
+**Content-Type:** `text/markdown`
+**Transfer-Encoding:** `chunked`
+
+**Response:** Streaming text/markdown chunks
+
+The response streams in real-time as the AI generates the answer. This provides a better user experience for longer responses.
+
+---
+
+### 5. GET /health
+
+**Simple health check endpoint to verify the API is running.**
+
+#### Request
+
+**URL:** `/health`
+**Method:** `GET`
+
+#### Response
+
+**Status Code:** `200 OK`
+
+**Response Body:**
+
+```json
+{
+  "status": "healthy",
+  "version": "1.0.0",
+  "timestamp": "2024-01-15T10:30:00.123456"
+}
+```
+
+**Fields:**
+- `status` (string): API health status ("healthy")
+- `version` (string): API version
+- `timestamp` (string): Current server timestamp in ISO format
+
+#### Use Case
+
+Use this endpoint to:
+- Verify the API is running and accessible
+- Monitor API availability
+- Check server time for debugging
+
+---
+
+## Additional Endpoints
+
+### API Information
+
+**GET** `/`
+
+**Response:**
+```json
+{
+  "name": "HBV AI Assistant API",
+  "version": "1.0.0",
+  "description": "HBV Patient Selection System - Evaluates patient eligibility for HBV treatment according to SASLT 2021 guidelines",
+  "docs": "/docs",
+  "endpoints": {
+    "assess": "/assess (POST) - Primary endpoint for HBV patient eligibility assessment",
+    "assess_text": "/assess/text (POST) - Text-based HBV patient eligibility assessment",
+    "ask": "/ask (POST) - Optional AI chat for guideline exploration",
+    "ask_stream": "/ask/stream (POST) - Streaming AI chat responses",
+    "health": "/health (GET) - Simple health check"
+  }
+}
+```
+
+---
+
+## Error Handling
+
+### HTTP Status Codes
+
+| Code | Description |
+|------|-------------|
+| 200 | Success |
+| 400 | Bad Request - Invalid input data |
+| 422 | Validation Error - Request body doesn't match schema |
+| 429 | Too Many Requests - Rate limit exceeded |
+| 500 | Internal Server Error |
+
+### Error Response Format
+
+```json
+{
+  "detail": "Error message describing what went wrong"
+}
+```
+
+### Validation Error Example
+
+**Request with invalid data:**
+```json
+{
+  "sex": "Unknown",   // Invalid: must be "Male" or "Female"
+  "age": 150          // Invalid: must be 0-120
+}
+```
+
+**Response (422):**
+```json
+{
+  "detail": [
+    {
+      "loc": ["body", "sex"],
+      "msg": "Sex must be either Male or Female",
+      "type": "value_error"
+    },
+    {
+      "loc": ["body", "age"],
+      "msg": "ensure this value is less than or equal to 120",
+      "type": "value_error.number.not_le"
+    }
+  ]
+}
+```
+
+---
+
+## Interactive API Documentation
+
+For interactive API documentation with a built-in testing interface, visit:
+
+**Swagger UI:** `http://127.0.0.1:8000/docs`
+**ReDoc:** `http://127.0.0.1:8000/redoc`
+
+These interfaces allow you to:
+- View all endpoints and their schemas
+- Test endpoints directly from the browser
+- See example requests and responses
+- Understand validation rules
+
+---
+
+## Rate Limiting
+
+The API implements rate limiting:
+- **Default limit:** 100 requests per minute per IP address
+- **Response header:** `X-RateLimit-Remaining` shows remaining requests
+- **Status code:** 429 when the limit is exceeded
+
+---
+
+## Best Practices
+
+1. **Error Handling:**
+   - Always check `response.ok` before parsing JSON
+   - Display validation errors to users clearly
+   - Implement retry logic for 500 errors
+
+2. **Session Management:**
+   - Use unique session IDs for different patient conversations
+   - Clear sessions when switching patients
+   - Consider using user ID + timestamp for session IDs
+
+3. **Performance:**
+   - Use `/ask/stream` for better UX on longer responses
+   - Cache assessment results when appropriate
+   - Consider debouncing text input for the `/assess/text` endpoint
+
+4. **Data Validation:**
+   - Validate input on the frontend before sending to the API
+   - Use the exact enum values specified in the documentation
+   - Handle validation errors gracefully
+
+---
+
+## Support
+
+For questions or issues:
+- Check the interactive documentation at `/docs`
+- Review error messages carefully; they indicate what's wrong
+- Ensure all required fields are provided with correct data types
+
+---
+
+**Last Updated:** 2024-01-15
+**API Version:** 1.0.0
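
The file's table of contents lists a "Code Examples" section that does not appear in the added lines. As a minimal illustration of the documented contract, a sketch like the following (assuming a running server at the documented development base URL and the `requests` package) exercises `/assess` and the chunked `/ask/stream` route; the payload values mirror the documented example:

```python
# Illustrative client sketch; base URL, timeouts, and error handling are
# assumptions, the field names come from the documentation above.
import requests

BASE_URL = "http://127.0.0.1:8000"

patient = {
    "sex": "Male",
    "age": 45,
    "pregnancy_status": "Not pregnant",
    "hbsag_status": "Positive",
    "duration_hbsag_months": 12,
    "hbv_dna_level": 5000,
    "hbeag_status": "Positive",
    "alt_level": 80,
    "fibrosis_stage": "F2-F3",
    "necroinflammatory_activity": "A2",
    "extrahepatic_manifestations": False,
    "family_history_cirrhosis_hcc": False,
}

resp = requests.post(f"{BASE_URL}/assess", json=patient, timeout=30)
if resp.ok:
    print("Eligible:", resp.json()["eligible"])
else:
    print("Error", resp.status_code, resp.json())

# /ask/stream returns chunked text/markdown, so iterate over the body
# instead of calling .json()
with requests.post(
    f"{BASE_URL}/ask/stream",
    json={"query": "What monitoring is required during HBV treatment?"},
    stream=True,
    timeout=120,
) as stream:
    for chunk in stream.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
```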
Dockerfile
ADDED
@@ -0,0 +1,50 @@
+# ========================
+# Stage 1 - Builder
+# ========================
+FROM python:3.11-slim AS builder
+
+# Install build dependencies
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    gcc \
+    g++ \
+    && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /app
+
+# Copy only requirements first (better caching)
+COPY requirements.txt .
+
+# Install dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# ========================
+# Stage 2 - Final Runtime Image
+# ========================
+FROM python:3.11-slim
+
+# No extra runtime packages are needed beyond the base image; just clean the apt lists
+RUN apt-get update && \
+    rm -rf /var/lib/apt/lists/*
+
+# Create a non-root user
+RUN useradd --create-home --shell /bin/bash appuser
+WORKDIR /app
+
+# Copy installed packages and application code
+COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
+COPY --from=builder /usr/local/bin /usr/local/bin
+COPY . /app
+
+# Set permissions
+RUN chown -R appuser:appuser /app
+USER appuser
+
+# Expose port
+EXPOSE 7860
+
+# Ensure the working directory is /app
+WORKDIR /app
+
+# Command to run the FastAPI app
+CMD ["sh", "-c", "uvicorn api.app:app --host 0.0.0.0 --port 7860"]
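
As a quick smoke test after building and starting the container (e.g. `docker run -p 7860:7860 <image>`, where the image name is up to you), a sketch like this polls the documented `/health` endpoint on the exposed port:

```python
# Illustrative smoke test; assumes the container maps port 7860 to the host
# and that /health behaves as documented in API_DOCUMENTATION.md.
import json
import time
import urllib.request

def wait_for_health(url: str = "http://127.0.0.1:7860/health", attempts: int = 10) -> dict:
    for _ in range(attempts):
        try:
            with urllib.request.urlopen(url, timeout=2) as resp:
                return json.load(resp)
        except OSError:
            time.sleep(1)  # the container may still be starting up
    raise RuntimeError("API did not become healthy in time")

print(wait_for_health())  # expected: {"status": "healthy", "version": "1.0.0", ...}
```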
README.md
ADDED
@@ -0,0 +1,237 @@
+---
+title: HBV AI Assistant - Patient Selection System
+emoji: 🏥
+colorFrom: blue
+colorTo: green
+sdk: docker
+pinned: false
+app_port: 7860
+---
+
+# HBV AI Assistant - Patient Selection System
+
+A specialized AI-powered clinical decision support system for hepatologists and healthcare professionals managing Hepatitis B Virus (HBV) patients. The system evaluates patient eligibility for treatment according to SASLT 2021 guidelines and provides evidence-based recommendations.
+
+## 🎯 Features
+
+### Core Capabilities
+- **Patient Eligibility Assessment**: Evaluates HBV patients for treatment eligibility based on SASLT 2021 guidelines
+- **Evidence-Based Guidance**: Provides treatment recommendations according to authoritative medical guidelines
+- **Comprehensive Input Validation**: Validates patient data including HBV DNA levels, ALT levels, fibrosis stage, and more
+- **Treatment Options**: Recommends preferred regimens (ETV, TDF, TAF) based on patient profile
+- **AI-Powered Chat**: Interactive AI bot for exploring guideline recommendations
+- **JSON API**: RESTful POST endpoint for programmatic integration
+
+### Technical Features
+- **FastAPI Backend**: High-performance async API
+- **Structured Input/Output**: Well-defined schemas for patient data and eligibility results
+- **Real-time Processing**: Fast eligibility determination
+- **Authentication**: Secure session-based authentication
+- **Rate Limiting**: Built-in API rate limiting
+- **CORS Support**: Cross-origin resource sharing enabled
+
+## 🚀 Deployment
+
+### Live API
+The API is deployed at: **http://127.0.0.1:7860**
+
+### Quick Start
+
+1. **Access the API**:
+   - API Docs: http://127.0.0.1:7860/docs
+   - Health Check: http://127.0.0.1:7860/health
+
+2. **Submit Patient Data**:
+   - Use the POST `/assess` endpoint to evaluate patient eligibility
+   - Provide patient information according to the input schema
+   - Receive eligibility determination and treatment recommendations
+
+### Deploy Your Own Instance
+
+See [DEPLOYMENT.md](DEPLOYMENT.md) for detailed deployment instructions.
+
+## 📚 API Endpoints
+
+### Health & Status
+- `GET /` - API information
+- `GET /health` - Health check
+
+### HBV Patient Assessment
+- `POST /assess` - Evaluate patient eligibility for HBV treatment using structured data
+- `POST /assess/text` - Text-based patient assessment (provide clinical notes in free text format)
+  - **Input**: Patient data (sex, age, HBV DNA, ALT, fibrosis stage, etc.)
+  - **Output**: Eligibility status and treatment recommendations
+
+### AI Chat
+- `POST /ask` - Ask guideline-related questions with optional patient context
+  - **Input**:
+    - `query` (string): The question or message
+    - `session_id` (string, optional): Session identifier for conversation history
+    - `patient_context` (object, optional): Patient data for context-aware responses
+    - `assessment_result` (object, optional): Previous assessment results for reference
+
+- `POST /ask/stream` - Streaming chat responses with optional patient context
+  - **Input**: Same as `/ask` endpoint
+  - **Output**: Stream of text chunks for real-time display
+
+## 💻 Local Development
+
+### Prerequisites
+- Python 3.11+
+- OpenAI API key (optional, for AI chat feature)
+
+### Setup
+
+1. **Clone the repository**:
+   ```bash
+   git clone https://github.com/your-repo/hbv-ai-assistant.git
+   cd hbv-ai-assistant
+   ```
+
+2. **Install dependencies**:
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+3. **Configure environment variables**:
+   ```bash
+   cp .env.example .env
+   # Edit .env with your API keys
+   ```
+
+4. **Run the application**:
+   ```bash
+   python app.py
+   ```
+
+5. **Access the application**:
+   - API: http://localhost:7860
+   - Docs: http://localhost:7860/docs
+   - Test the `/assess` endpoint with patient data
+
+## 🔧 Configuration
+
+### Environment Variables
+
+See `.env.example` for all configuration options:
+
+- `OPENAI_API_KEY`: Your OpenAI API key (optional, for AI chat)
+- `PORT`: Server port (default: 7860)
+- `ALLOWED_ORIGINS`: CORS allowed origins
+
+### Authentication
+
+Default credentials (change in production):
+- Username: `admin`
+- Password: `admin123`
+
+Update in `api/routers/auth.py` or via environment variables.
+
+## 📖 Usage Examples
+
+### Assessing Patient Eligibility
+
+```python
+import requests
+
+# Login
+response = requests.post(
+    "http://127.0.0.1:7860/auth/login",
+    json={"username": "admin", "password": "admin123"}
+)
+cookies = response.cookies
+
+# Assess patient eligibility
+patient_data = {
+    "sex": "Male",
+    "age": 45,
+    "pregnancy_status": "Not pregnant",
+    "hbsag_status": "Positive",
+    "duration_hbsag_months": 12,
+    "hbv_dna_level": 50000,
+    "hbeag_status": "Positive",
+    "alt_level": 60,
+    "fibrosis_stage": "F2-F3",
+    "necroinflammatory_activity": "A2",
+    "extrahepatic_manifestations": False,
+    "coinfections": [],
+    "family_history_cirrhosis_hcc": False
+}
+
+response = requests.post(
+    "http://127.0.0.1:7860/assess",
+    json=patient_data,
+    cookies=cookies
+)
+result = response.json()
+print(f"Eligible: {result['eligible']}")
+print(f"Recommendations: {result['recommendations']}")
+```
+
+## 🏗️ Architecture
+
+### Components
+
+- **FastAPI Backend**: RESTful API with async support
+- **Eligibility Engine**: Evaluates patient data against SASLT 2021 criteria
+- **AI Chat (Optional)**: LangChain-powered conversational interface for guideline exploration
+- **Validation Layer**: Ensures data integrity and completeness
+
+### Assessment Logic
+
+The system evaluates patients based on SASLT 2021 criteria:
+
+1. **HBV DNA > 2,000 IU/mL** + **ALT > ULN** + moderate necroinflammation/fibrosis (≥F2 or ≥A2)
+2. **Cirrhosis** (F4) with any detectable HBV DNA
+3. **HBV DNA > 20,000 IU/mL** + **ALT > 2×ULN** regardless of fibrosis
+4. **Age > 30** with HBeAg-positive chronic infection (normal ALT, high HBV DNA)
+5. **Family history** of HCC/cirrhosis + HBV DNA > 2,000 + ALT > ULN
+6. **Extrahepatic manifestations**
+
+## 📊 Response Format
+
+The API returns:
+- **Eligibility Status**: Eligible / Not Eligible
+- **Guideline Recommendations**: Specific criteria met and treatment options
+- **Treatment Choices**: Preferred regimens (ETV, TDF, TAF)
+
+Example Response:
+```json
+{
+  "eligible": true,
+  "recommendations": "Patient meets SASLT 2021 criteria: HBV DNA > 2,000 IU/mL, ALT > ULN, and fibrosis stage F2-F3 (Grade A)",
+  "treatment_options": ["ETV", "TDF", "TAF"],
+  "guideline": "SASLT 2021"
+}
+```
+
+## 🔒 Security
+
+- Session-based authentication
+- Rate limiting (100 requests/minute)
+- CORS protection
+- Input validation
+- Secure cookie handling
+
+## 📝 License
+
+[Add your license here]
+
+## 🤝 Contributing
+
+Contributions are welcome! Please read the contributing guidelines first.
+
+## 📧 Support
+
+For issues or questions:
+- Check the [DEPLOYMENT.md](DEPLOYMENT.md) guide
+- Review API docs at `/docs`
+- Open an issue on GitHub
+
+## 🙏 Acknowledgments
+
+Built with:
+- FastAPI
+- Pydantic
+- Python 3.11+
+- SASLT 2021 Guidelines
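
The README's example covers the structured `/assess` route; a companion sketch for the free-text route it lists, under the same assumptions (local server on port 7860, `requests` installed), might look like:

```python
# Illustrative call to /assess/text; the endpoint path, field name, and
# response shape follow the README and API_DOCUMENTATION.md above.
import requests

note = (
    "45-year-old male. HBsAg positive for 12 months. "
    "HBV DNA: 5000 IU/mL. HBeAg positive. ALT: 80 U/L. "
    "Fibrosis stage F2-F3, activity A2. No coinfections."
)

resp = requests.post(
    "http://127.0.0.1:7860/assess/text",
    json={"text_input": note},
    timeout=60,  # LLM-based parsing can take longer than the structured route
)
resp.raise_for_status()
print(resp.json()["recommendations"])
```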
api/__init__.py
ADDED
File without changes
api/app.py
ADDED
@@ -0,0 +1,118 @@
+
+import logging
+from contextlib import asynccontextmanager
+from fastapi import FastAPI, HTTPException
+from fastapi.exceptions import RequestValidationError
+from starlette.exceptions import HTTPException as StarletteHTTPException
+
+# Import routers
+from api.routers import medical, hbv_assessment
+from api.middleware import (
+    ProcessTimeMiddleware,
+    LoggingMiddleware,
+    RateLimitMiddleware,
+    get_cors_middleware_config
+)
+from fastapi.middleware.cors import CORSMiddleware
+from api.exceptions import (
+    http_exception_handler,
+    validation_exception_handler,
+    general_exception_handler,
+    starlette_exception_handler
+)
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Application lifespan management with background initialization"""
+    # Startup
+    logger.info("Starting HBV AI Assistant API...")
+
+    # Start background initialization of heavy components (optional for AI chat)
+    try:
+        from core.background_init import start_background_initialization
+        logger.info("🚀 Starting background initialization of components...")
+        start_background_initialization()
+        logger.info("API started successfully (components loading in background)")
+    except Exception as e:
+        logger.error(f"Failed to start background initialization: {e}")
+        logger.info("API started with lazy loading fallback")
+
+    yield
+
+    # Shutdown
+    logger.info("Shutting down HBV AI Assistant API...")
+
+
+# Create FastAPI application
+app = FastAPI(
+    title="HBV AI Assistant API",
+    description="HBV Patient Selection System - Evaluates patient eligibility for HBV treatment according to SASLT 2021 guidelines",
+    version="1.0.0",
+    docs_url="/docs",
+    redoc_url="/redoc",
+    lifespan=lifespan
+)
+
+# Add middleware
+app.add_middleware(CORSMiddleware, **get_cors_middleware_config())
+app.add_middleware(ProcessTimeMiddleware)
+app.add_middleware(LoggingMiddleware)
+app.add_middleware(RateLimitMiddleware, calls_per_minute=100)  # Adjust as needed
+
+# Add exception handlers
+app.add_exception_handler(HTTPException, http_exception_handler)
+app.add_exception_handler(RequestValidationError, validation_exception_handler)
+app.add_exception_handler(StarletteHTTPException, starlette_exception_handler)
+app.add_exception_handler(Exception, general_exception_handler)
+
+# Include routers
+app.include_router(hbv_assessment.router)  # Primary HBV assessment endpoint
+app.include_router(medical.router)  # Optional AI chat for guideline exploration
+
+# Root endpoint
+@app.get("/")
+async def root():
+    """Root endpoint with API information"""
+    return {
+        "name": "HBV AI Assistant API",
+        "version": "1.0.0",
+        "description": "HBV Patient Selection System - Evaluates patient eligibility for HBV treatment according to SASLT 2021 guidelines",
+        "docs": "/docs",
+        "endpoints": {
+            "assess": "/assess (POST) - Primary endpoint for HBV patient eligibility assessment",
+            "assess_text": "/assess/text (POST) - Text-based HBV patient eligibility assessment",
+            "ask": "/ask (POST) - Optional AI chat for guideline exploration",
+            "ask_stream": "/ask/stream (POST) - Streaming AI chat responses",
+            "health": "/health (GET) - Simple health check",
+        }
+    }
+
+
+# Simple health check endpoint
+@app.get("/health")
+async def health_check():
+    """Simple health check endpoint"""
+    from datetime import datetime
+    return {
+        "status": "healthy",
+        "version": "1.0.0",
+        "timestamp": datetime.now().isoformat()
+    }
+
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(
+        "api.app:app",
+        host="127.0.0.1",
+        port=8000,
+        reload=True,
+        log_level="info"
+    )
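
A minimal sketch of exercising this app in-process with FastAPI's TestClient; it assumes `fastapi` and `httpx` are installed and that `api.app` imports cleanly (its routers pull in the `core/` modules):

```python
# Illustrative only; not part of the repository's test suite.
from fastapi.testclient import TestClient

from api.app import app

# Using the client as a context manager also runs the lifespan hook,
# which kicks off the background initialization defined above.
with TestClient(app) as client:
    resp = client.get("/health")
    assert resp.status_code == 200
    assert resp.json()["status"] == "healthy"

    resp = client.get("/")
    print(resp.json()["endpoints"])  # the routes registered in this file
```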
api/exceptions.py
ADDED
@@ -0,0 +1,86 @@
+"""
+Exception handlers for Medical RAG AI Advisor API
+"""
+import logging
+from datetime import datetime
+from fastapi import Request, HTTPException
+from fastapi.responses import JSONResponse
+from fastapi.exceptions import RequestValidationError
+from starlette.exceptions import HTTPException as StarletteHTTPException
+
+from api.models import ErrorResponse
+
+logger = logging.getLogger(__name__)
+
+
+async def http_exception_handler(request: Request, exc: HTTPException):
+    """Handle HTTP exceptions"""
+    logger.error(f"HTTP Exception: {exc.status_code} - {exc.detail}")
+
+    return JSONResponse(
+        status_code=exc.status_code,
+        content=ErrorResponse(
+            error="HTTP_ERROR",
+            message=exc.detail,
+            timestamp=datetime.now().isoformat()
+        ).dict()
+    )
+
+
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    """Handle request validation errors"""
+    errors = exc.errors()
+
+    # Convert errors to JSON-serializable format
+    serializable_errors = []
+    for error in errors:
+        error_dict = {
+            "type": error.get("type"),
+            "loc": list(error.get("loc", [])),
+            "msg": error.get("msg"),
+            "input": str(error.get("input"))  # Convert to string to ensure serializability
+        }
+        if "ctx" in error:
+            error_dict["ctx"] = {k: str(v) for k, v in error["ctx"].items()}
+        serializable_errors.append(error_dict)
+
+    logger.error(f"Validation Error: {serializable_errors}")
+
+    return JSONResponse(
+        status_code=422,
+        content=ErrorResponse(
+            error="VALIDATION_ERROR",
+            message="Request validation failed",
+            details={"validation_errors": serializable_errors},
+            timestamp=datetime.now().isoformat()
+        ).dict()
+    )
+
+
+async def general_exception_handler(request: Request, exc: Exception):
+    """Handle general exceptions"""
+    logger.error(f"Unhandled Exception: {type(exc).__name__} - {str(exc)}")
+
+    return JSONResponse(
+        status_code=500,
+        content=ErrorResponse(
+            error="INTERNAL_SERVER_ERROR",
+            message="An internal server error occurred",
+            details={"exception_type": type(exc).__name__},
+            timestamp=datetime.now().isoformat()
+        ).dict()
+    )
+
+
+async def starlette_exception_handler(request: Request, exc: StarletteHTTPException):
+    """Handle Starlette HTTP exceptions"""
+    logger.error(f"Starlette HTTP Exception: {exc.status_code} - {exc.detail}")
+
+    return JSONResponse(
+        status_code=exc.status_code,
+        content=ErrorResponse(
+            error="HTTP_ERROR",
+            message=exc.detail,
+            timestamp=datetime.now().isoformat()
+        ).dict()
+    )
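
A sketch of how a client might unpack the envelope these handlers produce (the field names follow the `ErrorResponse` usage above; `ErrorResponse` itself is defined in api/models.py, which is not shown here):

```python
# Illustrative client-side handling of the error envelope emitted by these
# handlers: {"error": ..., "message": ..., "details": ..., "timestamp": ...}
import requests

resp = requests.post("http://127.0.0.1:8000/assess", json={"sex": "Unknown"}, timeout=30)
if not resp.ok:
    body = resp.json()
    print(body.get("error"), "-", body.get("message"))
    # Validation failures carry per-field entries under details.validation_errors
    for item in body.get("details", {}).get("validation_errors", []):
        print(item["loc"], item["msg"])
```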
api/middleware.py
ADDED
@@ -0,0 +1,177 @@
+"""
+Middleware for Medical RAG AI Advisor API
+"""
+import time
+import logging
+from typing import Callable, Awaitable, Optional
+from fastapi import Request, Response, HTTPException, Cookie
+from fastapi.middleware.cors import CORSMiddleware
+from starlette.middleware.base import BaseHTTPMiddleware
+
+logger = logging.getLogger(__name__)
+
+
+class ProcessTimeMiddleware(BaseHTTPMiddleware):
+    """Middleware to add processing time to response headers"""
+
+    async def dispatch(self, request: Request, call_next: Callable) -> Response:
+        start_time = time.time()
+        response = await call_next(request)
+        process_time = time.time() - start_time
+        response.headers["X-Process-Time"] = f"{process_time:.4f}"
+        return response
+
+
+class LoggingMiddleware(BaseHTTPMiddleware):
+    """Middleware for request/response logging"""
+
+    async def dispatch(self, request: Request, call_next: Callable) -> Response:
+        start_time = time.time()
+
+        # Log request
+        logger.info(f"Request: {request.method} {request.url}")
+
+        try:
+            response = await call_next(request)
+            process_time = time.time() - start_time
+
+            # Log response
+            logger.info(
+                f"Response: {response.status_code} - "
+                f"Time: {process_time:.4f}s - "
+                f"Path: {request.url.path}"
+            )
+
+            return response
+
+        except Exception as e:
+            process_time = time.time() - start_time
+            logger.error(
+                f"Error: {str(e)} - "
+                f"Time: {process_time:.4f}s - "
+                f"Path: {request.url.path}"
+            )
+            raise
+
+
+class RateLimitMiddleware(BaseHTTPMiddleware):
+    """Simple rate limiting middleware"""
+
+    def __init__(self, app, calls_per_minute: int = 60):
+        super().__init__(app)
+        self.calls_per_minute = calls_per_minute
+        self.client_calls = {}
+
+    async def dispatch(self, request: Request, call_next: Callable) -> Response:
+        client_ip = request.client.host
+        current_time = time.time()
+
+        # Clean old entries
+        self.client_calls = {
+            ip: calls for ip, calls in self.client_calls.items()
+            if any(call_time > current_time - 60 for call_time in calls)
+        }
+
+        # Check rate limit
+        if client_ip in self.client_calls:
+            recent_calls = [
+                call_time for call_time in self.client_calls[client_ip]
+                if call_time > current_time - 60
+            ]
+            if len(recent_calls) >= self.calls_per_minute:
+                raise HTTPException(
+                    status_code=429,
+                    detail="Rate limit exceeded. Please try again later."
+                )
+            self.client_calls[client_ip] = recent_calls + [current_time]
+        else:
+            self.client_calls[client_ip] = [current_time]
+
+        return await call_next(request)
+
+
+class AuthenticationMiddleware(BaseHTTPMiddleware):
+    """Middleware to protect endpoints with session authentication"""
+
+    # Paths that don't require authentication
+    PUBLIC_PATHS = [
+        "/",
+        "/docs",
+        "/redoc",
+        "/openapi.json",
+        "/health",
+        "/auth/login",
+        "/auth/status",
+        "/assess",  # Allow assess endpoint for local testing
+        "/ask",  # Allow ask endpoint for local testing
+    ]
+
+    async def dispatch(self, request: Request, call_next: Callable) -> Response:
+        # For local testing, disable authentication
+        # TODO: Enable authentication in production
+        return await call_next(request)
+
+        # Original authentication code (disabled for local testing)
+        # # Check if path is public
+        # path = request.url.path
+        #
+        # # Allow public paths
+        # if any(path.startswith(public_path) for public_path in self.PUBLIC_PATHS):
+        #     return await call_next(request)
+        #
+        # # Check for session token
+        # session_token = request.cookies.get("session_token")
+        #
+        # if not session_token:
+        #     raise HTTPException(
+        #         status_code=401,
+        #         detail="Authentication required"
+        #     )
+        #
+        # # Verify session
+        # from api.routers.auth import verify_session
+        # session_data = verify_session(session_token)
+        #
+        # if not session_data:
+        #     raise HTTPException(
+        #         status_code=401,
+        #         detail="Invalid or expired session"
+        #     )
+        #
+        # # Add user info to request state
+        # request.state.user = session_data.get("username")
+        #
+        # return await call_next(request)
+
+
+def get_cors_middleware_config():
+    """Get CORS middleware configuration"""
+    import os
+
+    # Get allowed origins from environment or use defaults
+    allowed_origins = os.getenv("ALLOWED_ORIGINS", "").split(",")
+    if not allowed_origins or allowed_origins == [""]:
+        # Default to allowing Hugging Face Space and localhost
+        # Include null for file:// protocol and common local development origins
+        allowed_origins = [
+            "http://127.0.0.1:7860",
+            "https://huggingface.co",
+            "http://localhost:8000",
+            "http://127.0.0.1:8000",
+            "http://localhost:8080",  # Frontend server
+            "http://127.0.0.1:8080",
+            "http://localhost:5500",  # Live Server default port
+            "http://127.0.0.1:5500",
+            "http://localhost:3000",  # Common dev server port
+            "http://127.0.0.1:3000",
+            "null",  # For file:// protocol
+            "*"  # Allow all origins for local testing
+        ]
+
+    return {
+        "allow_origins": ["*"],  # Allow all origins for local testing
+        "allow_credentials": False,  # Must be False when allow_origins is "*"
|
| 174 |
+
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
| 175 |
+
"allow_headers": ["*"], # Allow all headers
|
| 176 |
+
"expose_headers": ["Set-Cookie"],
|
| 177 |
+
}
|
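
The CORS settings above are returned as a plain dict rather than installed as middleware, so the application has to apply them itself. Below is a minimal wiring sketch, not the repository's actual setup: the real wiring lives in api/app.py (not shown in this section), and the `app` object here is an assumption.

# Hypothetical wiring sketch; the actual setup lives in api/app.py.
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from api.middleware import (
    LoggingMiddleware,
    ProcessTimeMiddleware,
    RateLimitMiddleware,
    get_cors_middleware_config,
)

app = FastAPI()

# The dict keys match CORSMiddleware's keyword arguments
# (allow_origins, allow_credentials, allow_methods, allow_headers, expose_headers).
app.add_middleware(CORSMiddleware, **get_cors_middleware_config())
app.add_middleware(ProcessTimeMiddleware)
app.add_middleware(LoggingMiddleware)
app.add_middleware(RateLimitMiddleware, calls_per_minute=60)
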
api/models.py
ADDED
@@ -0,0 +1,188 @@

"""
API Models and Schemas for HBV AI Assistant
"""
from pydantic import BaseModel, Field, validator
from typing import Optional, List, Dict, Any
from enum import Enum


class QueryType(str, Enum):
    """Types of medical queries supported"""
    GENERAL = "general"
    DRUG_INTERACTION = "drug_interaction"
    SIDE_EFFECTS = "side_effects"
    GUIDELINES = "guidelines"
    COMPARISON = "comparison"


class QueryRequest(BaseModel):
    """Request model for medical queries"""
    query: str = Field(..., description="Medical question or query", min_length=1)
    query_type: Optional[QueryType] = Field(None, description="Type of medical query")
    context: Optional[str] = Field(None, description="Additional context for the query")
    patient_info: Optional[Dict[str, Any]] = Field(None, description="Patient information if relevant")


class QueryResponse(BaseModel):
    """Response model for medical queries"""
    response: str = Field(..., description="AI-generated medical response")
    sources: Optional[List[str]] = Field(None, description="Sources used for the response")
    confidence: Optional[float] = Field(None, description="Confidence score of the response")
    query_type: Optional[QueryType] = Field(None, description="Detected or specified query type")
    processing_time: Optional[float] = Field(None, description="Time taken to process the query")


class StreamChunk(BaseModel):
    """Model for streaming response chunks"""
    chunk: str = Field(..., description="Chunk of the streaming response")
    is_final: bool = Field(False, description="Whether this is the final chunk")


class SideEffectReport(BaseModel):
    """Model for side effect reporting"""
    drug_name: str = Field(..., description="Name of the drug")
    side_effects: str = Field(..., description="Reported side effects")
    patient_age: Optional[int] = Field(None, description="Patient age")
    patient_gender: Optional[str] = Field(None, description="Patient gender")
    dosage: Optional[str] = Field(None, description="Drug dosage")
    duration: Optional[str] = Field(None, description="Duration of treatment")
    severity: Optional[str] = Field(None, description="Severity of side effects")
    outcome: Optional[str] = Field(None, description="Outcome of the side effects")
    additional_details: Optional[str] = Field(None, description="Additional clinical details")


class SideEffectResponse(BaseModel):
    """Response model for side effect reporting"""
    report_id: str = Field(..., description="Unique identifier for the report")
    status: str = Field(..., description="Status of the report submission")
    message: str = Field(..., description="Confirmation message")
    recommendations: Optional[List[str]] = Field(None, description="Clinical recommendations")


class ComparisonRequest(BaseModel):
    """Request model for provider/treatment comparisons"""
    providers: List[str] = Field(..., description="List of providers to compare", min_items=2)
    criteria: Optional[List[str]] = Field(None, description="Specific criteria for comparison")


class ComparisonResponse(BaseModel):
    """Response model for provider/treatment comparisons"""
    comparison: str = Field(..., description="Detailed comparison analysis")
    summary: Dict[str, Any] = Field(..., description="Summary of key differences")
    recommendations: Optional[str] = Field(None, description="Recommendations based on comparison")


class InitializationStatus(BaseModel):
    """Initialization status response model"""
    is_complete: bool = Field(..., description="Whether initialization is complete")
    status_message: str = Field(..., description="Current initialization status")
    is_successful: bool = Field(..., description="Whether initialization was successful")
    error: Optional[str] = Field(None, description="Initialization error if any")


class HealthStatus(BaseModel):
    """Health check response model"""
    status: str = Field(..., description="API health status")
    version: str = Field(..., description="API version")
    timestamp: str = Field(..., description="Current timestamp")
    components: Dict[str, str] = Field(..., description="Status of system components")
    initialization: Optional[InitializationStatus] = Field(None, description="Background initialization status")


class ErrorResponse(BaseModel):
    """Error response model"""
    error: str = Field(..., description="Error type")
    message: str = Field(..., description="Error message")
    details: Optional[Dict[str, Any]] = Field(None, description="Additional error details")
    timestamp: str = Field(..., description="Error timestamp")


# ============================================================================
# HBV PATIENT ASSESSMENT MODELS
# ============================================================================

class HBVPatientInput(BaseModel):
    """Input model for HBV patient assessment"""
    sex: str = Field(..., description="Patient sex: Male / Female")
    age: int = Field(..., description="Patient age in years", ge=0, le=120)
    pregnancy_status: str = Field(..., description="Pregnancy status: Not pregnant / Pregnant")
    hbsag_status: str = Field(..., description="HBsAg status: Positive / Negative")
    duration_hbsag_months: int = Field(..., description="Duration of HBsAg positivity in months", ge=0)
    hbv_dna_level: float = Field(..., description="HBV DNA level in IU/mL", ge=0)
    hbeag_status: str = Field(..., description="HBeAg status: Positive / Negative")
    alt_level: float = Field(..., description="ALT level in U/L", ge=0)
    fibrosis_stage: str = Field(..., description="Fibrosis/Cirrhosis stage: F0-F1 / F2-F3 / F4")
    necroinflammatory_activity: str = Field(..., description="Necroinflammatory activity: A0 / A1 / A2 / A3")
    extrahepatic_manifestations: bool = Field(..., description="Presence of extrahepatic manifestations")
    immunosuppression_status: Optional[str] = Field(None, description="Immunosuppression status: None / Chemotherapy / Other")
    coinfections: List[str] = Field(default_factory=list, description="Coinfections: HIV, HCV, HDV")
    family_history_cirrhosis_hcc: bool = Field(..., description="Family history of Cirrhosis or HCC (first-degree relative)")
    other_comorbidities: Optional[List[str]] = Field(None, description="Other comorbidities")

    @validator('sex')
    def validate_sex(cls, v):
        if v not in ['Male', 'Female']:
            raise ValueError('Sex must be either Male or Female')
        return v

    @validator('pregnancy_status')
    def validate_pregnancy(cls, v):
        if v not in ['Not pregnant', 'Pregnant']:
            raise ValueError('Pregnancy status must be either "Not pregnant" or "Pregnant"')
        return v

    @validator('hbsag_status')
    def validate_hbsag(cls, v):
        if v not in ['Positive', 'Negative']:
            raise ValueError('HBsAg status must be either Positive or Negative')
        return v

    @validator('hbeag_status')
    def validate_hbeag(cls, v):
        if v not in ['Positive', 'Negative']:
            raise ValueError('HBeAg status must be either Positive or Negative')
        return v

    @validator('fibrosis_stage')
    def validate_fibrosis(cls, v):
        if v not in ['F0-F1', 'F2-F3', 'F4']:
            raise ValueError('Fibrosis stage must be F0-F1, F2-F3, or F4')
        return v

    @validator('necroinflammatory_activity')
    def validate_necroinflammatory(cls, v):
        if v not in ['A0', 'A1', 'A2', 'A3']:
            raise ValueError('Necroinflammatory activity must be A0, A1, A2, or A3')
        return v


class HBVAssessmentResponse(BaseModel):
    """Response model for HBV patient assessment"""
    eligible: bool = Field(..., description="Whether patient is eligible for treatment")
    recommendations: str = Field(..., description="Detailed recommendations with citations from SASLT 2021 guidelines including page numbers")


class TextAssessmentInput(BaseModel):
    """Input model for text-based HBV patient assessment"""
    text_input: str = Field(..., description="Free-form text containing patient data", min_length=10)

    class Config:
        json_schema_extra = {
            "example": {
                "text_input": "45-year-old male patient\nHBsAg: Positive for 12 months\nHBV DNA: 5000 IU/mL\nHBeAg: Positive\nALT: 80 U/L\nFibrosis stage: F2-F3\nNecroinflammatory activity: A2\nNo extrahepatic manifestations\nNo immunosuppression\nNo coinfections (HIV, HCV, HDV)\nNo family history of cirrhosis or HCC"
            }
        }


class ChatRequest(BaseModel):
    """Request model for chat interactions"""
    query: str = Field(..., description="Doctor's question about HBV guidelines", min_length=1)
    session_id: str = Field(default="default", description="Session identifier for conversation continuity")
    patient_context: Optional[HBVPatientInput] = Field(None, description="Optional patient context from assessment")
    assessment_result: Optional[HBVAssessmentResponse] = Field(None, description="Optional assessment result for context")


class ChatResponse(BaseModel):
    """Response model for chat interactions"""
    response: str = Field(..., description="AI response to the doctor's question")
    session_id: str = Field(..., description="Session identifier")
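
For reference, a short sketch of constructing a valid HBVPatientInput in code; the field values mirror the TextAssessmentInput example above and are illustrative only, not clinical guidance.

# Illustrative values only, mirroring the example payload above.
from api.models import HBVPatientInput

patient = HBVPatientInput(
    sex="Male",
    age=45,
    pregnancy_status="Not pregnant",
    hbsag_status="Positive",
    duration_hbsag_months=12,
    hbv_dna_level=5000,
    hbeag_status="Positive",
    alt_level=80,
    fibrosis_stage="F2-F3",
    necroinflammatory_activity="A2",
    extrahepatic_manifestations=False,
    immunosuppression_status="None",
    coinfections=[],
    family_history_cirrhosis_hcc=False,
)

# Out-of-range values are rejected by the @validator hooks, e.g.
# fibrosis_stage="F5" raises a pydantic ValidationError at construction time.
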
api/routers/__init__.py
ADDED
File without changes
api/routers/auth.py
ADDED
@@ -0,0 +1,201 @@

"""
Authentication router for simple login system
"""
import os
import secrets
from datetime import datetime, timedelta
from typing import Dict, Optional
from fastapi import APIRouter, HTTPException, Response, Cookie, Form, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
import logging
from urllib.parse import urlparse

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/auth", tags=["Authentication"])

# Session management
SESSION_SECRET_KEY = os.getenv("SESSION_SECRET_KEY", secrets.token_hex(32))
SESSION_MAX_AGE = 86400  # 24 hours in seconds
serializer = URLSafeTimedSerializer(SESSION_SECRET_KEY)

# In-memory session store (for a simple use case)
# For production, consider using Redis or a database
active_sessions: Dict[str, dict] = {}


class LoginRequest(BaseModel):
    username: str
    password: str


class LoginResponse(BaseModel):
    success: bool
    message: str


def create_session(username: str) -> str:
    """Create a new session token"""
    session_id = secrets.token_urlsafe(32)
    session_data = {
        "username": username,
        "created_at": datetime.utcnow().isoformat(),
        "expires_at": (datetime.utcnow() + timedelta(seconds=SESSION_MAX_AGE)).isoformat()
    }

    # Store session
    active_sessions[session_id] = session_data

    # Create signed token
    token = serializer.dumps(session_id)
    return token


def verify_session(token: Optional[str]) -> Optional[dict]:
    """Verify session token and return session data"""
    if not token:
        return None

    try:
        # Verify signature and age
        session_id = serializer.loads(token, max_age=SESSION_MAX_AGE)

        # Check if session exists
        session_data = active_sessions.get(session_id)
        if not session_data:
            return None

        # Check expiration
        expires_at = datetime.fromisoformat(session_data["expires_at"])
        if datetime.utcnow() > expires_at:
            # Clean up expired session
            active_sessions.pop(session_id, None)
            return None

        return session_data
    except (BadSignature, SignatureExpired):
        return None
    except Exception as e:
        logger.error(f"Session verification error: {e}")
        return None


def verify_credentials(username: str, password: str) -> bool:
    """Verify username and password against the hardcoded demo credentials"""
    expected_username = "volaris"
    expected_password = "volaris123"

    return username == expected_username and password == expected_password


@router.post("/login", response_model=LoginResponse)
async def login(
    response: Response,
    request: Request,
    username: str = Form(...),
    password: str = Form(...)
):
    """
    Login endpoint - validates credentials and creates a session
    """
    # Log login attempt
    logger.info(f"Login attempt for username: {username}, Origin: {request.headers.get('origin')}")

    # Verify credentials
    if not verify_credentials(username, password):
        logger.warning(f"Failed login attempt for username: {username}")
        raise HTTPException(status_code=401, detail="Invalid username or password")

    # Create session
    token = create_session(username)
    logger.info(f"Session created for user: {username}")

    # Set secure cookie
    # Detect if we're running on HTTPS (Hugging Face Spaces use HTTPS)
    is_https = request.url.scheme == "https" or request.headers.get("x-forwarded-proto") == "https"

    # For HTTPS (production/HF Spaces), use SameSite=None with Secure=True for cross-origin
    # For HTTP (local dev), use SameSite=Lax with Secure=False
    if is_https:
        samesite = "none"
        secure = True
    else:
        samesite = "lax"
        secure = False

    logger.info(f"Setting cookie with samesite={samesite}, secure={secure}, is_https={is_https}")

    response.set_cookie(
        key="session_token",
        value=token,
        httponly=True,
        max_age=SESSION_MAX_AGE,
        samesite=samesite,
        secure=secure,
        path="/"
    )

    logger.info(f"Successful login for user: {username}")

    return LoginResponse(
        success=True,
        message="Login successful"
    )


@router.post("/logout")
async def logout(
    response: Response,
    session_token: Optional[str] = Cookie(None)
):
    """
    Logout endpoint - invalidates the session
    """
    if session_token:
        try:
            session_id = serializer.loads(session_token, max_age=SESSION_MAX_AGE)
            active_sessions.pop(session_id, None)
        except Exception:
            pass

    # Clear cookie
    response.delete_cookie(key="session_token")

    return {"success": True, "message": "Logged out successfully"}


@router.get("/verify")
async def verify(session_token: Optional[str] = Cookie(None)):
    """
    Verify whether the current session is valid
    """
    session_data = verify_session(session_token)

    if not session_data:
        raise HTTPException(status_code=401, detail="Not authenticated")

    return {
        "authenticated": True,
        "username": session_data.get("username")
    }


@router.get("/status")
async def status(request: Request, session_token: Optional[str] = Cookie(None)):
    """
    Check authentication status without raising an exception
    """
    logger.info(f"Status check - Cookie present: {session_token is not None}, Origin: {request.headers.get('origin')}")
    session_data = verify_session(session_token)

    if session_data:
        logger.info(f"Status check - Authenticated as: {session_data.get('username')}")
    else:
        logger.info("Status check - Not authenticated")

    return {
        "authenticated": session_data is not None,
        "username": session_data.get("username") if session_data else None
    }
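
A sketch of the resulting login flow from a Python client, assuming a local instance on the default port 7860; requests.Session carries the session_token cookie across calls.

# Hedged client sketch; base URL and port are assumptions for a local run.
import requests

BASE = "http://127.0.0.1:7860"

with requests.Session() as s:
    # /auth/login expects form fields, not JSON (username/password via Form(...))
    r = s.post(f"{BASE}/auth/login",
               data={"username": "volaris", "password": "volaris123"})
    r.raise_for_status()  # the session_token cookie is now stored on the session

    # The cookie is sent automatically on subsequent requests
    print(s.get(f"{BASE}/auth/status").json())
    s.post(f"{BASE}/auth/logout")
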
api/routers/hbv_assessment.py
ADDED
@@ -0,0 +1,116 @@

"""
HBV Patient Assessment Router
API endpoint for HBV treatment eligibility assessment
"""
from fastapi import APIRouter, HTTPException
from api.models import HBVPatientInput, HBVAssessmentResponse, TextAssessmentInput
import logging
from core.hbv_assessment import assess_hbv_eligibility
from core.text_parser import parse_patient_text, validate_extracted_data

logger = logging.getLogger(__name__)

router = APIRouter(
    prefix="",
    tags=["HBV Assessment"]
)


@router.post("/assess", response_model=HBVAssessmentResponse)
async def assess_patient(patient: HBVPatientInput) -> HBVAssessmentResponse:
    """
    Assess HBV patient eligibility for treatment according to SASLT 2021 guidelines

    This endpoint:
    1. Validates patient data
    2. Creates an intelligent search query based on patient parameters
    3. Retrieves relevant SASLT 2021 guidelines from the vector store
    4. Uses an LLM to analyze the patient against the retrieved guidelines
    5. Returns a structured assessment with eligibility and comprehensive recommendations

    Returns:
        HBVAssessmentResponse containing:
        - eligible: Whether the patient is eligible for treatment
        - recommendations: Comprehensive narrative including eligibility determination,
          specific criteria met, treatment options (ETV, TDF, TAF), and special considerations,
          with inline citations in the format [SASLT 2021, Page X]
    """
    try:
        logger.info(f"Assessing HBV patient: Age {patient.age}, Sex {patient.sex}, HBV DNA {patient.hbv_dna_level}")

        # Convert Pydantic model to dict for the core function
        patient_data = patient.dict()

        # Call the core assessment function
        result = assess_hbv_eligibility(patient_data)

        # Convert the dict result back to the Pydantic response model
        response = HBVAssessmentResponse(**result)

        logger.info(f"Assessment complete: Eligible={response.eligible}")
        return response

    except Exception as e:
        logger.error(f"Error assessing patient: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error assessing patient: {str(e)}")


@router.post("/assess/text", response_model=HBVAssessmentResponse)
async def assess_patient_from_text(text_input: TextAssessmentInput) -> HBVAssessmentResponse:
    """
    Assess HBV patient eligibility from free-form text input

    This endpoint:
    1. Parses free-form text to extract structured patient data using an LLM
    2. Validates the extracted data
    3. Performs the same assessment as the /assess endpoint
    4. Returns a structured assessment with eligibility and recommendations

    Example text input:
        "45-year-old male patient
        HBsAg: Positive for 12 months
        HBV DNA: 5000 IU/mL
        HBeAg: Positive
        ALT: 80 U/L
        Fibrosis stage: F2-F3
        Necroinflammatory activity: A2
        No extrahepatic manifestations
        No immunosuppression
        No coinfections (HIV, HCV, HDV)
        No family history of cirrhosis or HCC"

    Returns:
        HBVAssessmentResponse containing:
        - eligible: Whether the patient is eligible for treatment
        - recommendations: Comprehensive narrative with inline citations
    """
    try:
        logger.info(f"Received text input for assessment (length: {len(text_input.text_input)} characters)")
        logger.info(f"Text input preview: {text_input.text_input[:200]}...")

        # Parse text to extract structured patient data
        logger.info("Parsing text input to extract patient data...")
        patient_data = parse_patient_text(text_input.text_input)
        logger.info(f"Extracted patient data: {patient_data}")

        # Validate extracted data
        logger.info("Validating extracted patient data...")
        validated_data = validate_extracted_data(patient_data)
        logger.info("Patient data validated successfully")

        # Call the core assessment function with the extracted data
        logger.info("Performing HBV eligibility assessment...")
        result = assess_hbv_eligibility(validated_data)

        # Convert the dict result to the Pydantic response model
        response = HBVAssessmentResponse(**result)

        logger.info(f"Text-based assessment complete: Eligible={response.eligible}")
        return response

    except ValueError as e:
        logger.error(f"Validation error in text assessment: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Invalid patient data: {str(e)}")
    except Exception as e:
        logger.error(f"Error in text-based assessment: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing text input: {str(e)}")
api/routers/health.py
ADDED
@@ -0,0 +1,228 @@

"""
Health Check and System Status Router
"""
from datetime import datetime
from fastapi import APIRouter
import sys
import os

# Add the project root to the path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

from api.models import HealthStatus, InitializationStatus

router = APIRouter(prefix="/health", tags=["health"])


@router.get("/", response_model=HealthStatus)
async def health_check():
    """
    Check the health status of the API and its components
    """
    components = {}

    # Check agent availability
    try:
        from agent import safe_run_agent
        components["agent"] = "healthy"
    except Exception:
        components["agent"] = "unhealthy"

    # Check vector store
    try:
        from vector_store import VectorStore
        components["vector_store"] = "healthy"
    except Exception:
        components["vector_store"] = "unhealthy"

    # Check data loaders
    try:
        from data_loaders import load_pdf_documents
        components["data_loaders"] = "healthy"
    except Exception:
        components["data_loaders"] = "unhealthy"

    # Check tools
    try:
        from tools import medical_guidelines_knowledge_tool
        components["tools"] = "healthy"
    except Exception:
        components["tools"] = "unhealthy"

    # Check initialization status
    initialization_status = None
    try:
        from background_init import (
            is_initialization_complete,
            get_initialization_status,
            is_initialization_successful,
            get_initialization_error
        )

        initialization_status = InitializationStatus(
            is_complete=is_initialization_complete(),
            status_message=get_initialization_status(),
            is_successful=is_initialization_successful(),
            error=str(get_initialization_error()) if get_initialization_error() else None
        )
    except Exception as e:
        initialization_status = InitializationStatus(
            is_complete=False,
            status_message=f"Unable to check initialization status: {str(e)}",
            is_successful=False,
            error=str(e)
        )

    # Overall status
    overall_status = "healthy" if all(
        status == "healthy" for status in components.values()
    ) else "degraded"

    return HealthStatus(
        status=overall_status,
        version="1.0.0",
        timestamp=datetime.now().isoformat(),
        components=components,
        initialization=initialization_status
    )


@router.get("/ping")
async def ping():
    """
    Simple ping endpoint for basic connectivity check
    """
    return {"message": "pong", "timestamp": datetime.now().isoformat()}


@router.get("/initialization", response_model=InitializationStatus)
async def get_initialization_status():
    """
    Get the current initialization status of background components
    """
    try:
        from background_init import (
            is_initialization_complete,
            get_initialization_status,
            is_initialization_successful,
            get_initialization_error
        )

        return InitializationStatus(
            is_complete=is_initialization_complete(),
            status_message=get_initialization_status(),
            is_successful=is_initialization_successful(),
            error=str(get_initialization_error()) if get_initialization_error() else None
        )
    except Exception as e:
        return InitializationStatus(
            is_complete=False,
            status_message=f"Unable to check initialization status: {str(e)}",
            is_successful=False,
            error=str(e)
        )


@router.get("/version")
async def get_version():
    """
    Get API version information
    """
    return {
        "version": "1.0.0",
        "name": "Medical RAG AI Advisor API",
        "description": "Professional API for medical information retrieval and advisory services",
        "build_date": "2024-01-01"
    }


@router.get("/sessions")
async def get_active_sessions():
    """
    Get list of all active conversation sessions
    """
    try:
        from core.agent import get_active_sessions
        sessions = get_active_sessions()
        return {
            "active_sessions": sessions,
            "count": len(sessions),
            "timestamp": datetime.now().isoformat()
        }
    except Exception as e:
        return {
            "error": str(e),
            "active_sessions": [],
            "count": 0
        }


@router.delete("/sessions/{session_id}")
async def clear_session(session_id: str):
    """
    Clear conversation memory for a specific session

    Args:
        session_id: The session identifier to clear
    """
    try:
        from core.agent import clear_session_memory
        success = clear_session_memory(session_id)
        if success:
            return {
                "message": f"Session '{session_id}' cleared successfully",
                "session_id": session_id,
                "timestamp": datetime.now().isoformat()
            }
        else:
            return {
                "message": f"Session '{session_id}' not found",
                "session_id": session_id,
                "timestamp": datetime.now().isoformat()
            }
    except Exception as e:
        return {
            "error": str(e),
            "session_id": session_id
        }


@router.delete("/sessions")
async def clear_all_sessions():
    """
    Clear all conversation sessions
    """
    try:
        from core.agent import clear_memory
        clear_memory()
        return {
            "message": "All sessions cleared successfully",
            "timestamp": datetime.now().isoformat()
        }
    except Exception as e:
        return {
            "error": str(e)
        }


@router.get("/sessions/{session_id}/summary")
async def get_session_summary(session_id: str):
    """
    Get conversation history summary for a specific session

    Args:
        session_id: The session identifier
    """
    try:
        from core.agent import get_memory_summary
        summary = get_memory_summary(session_id)
        return {
            "session_id": session_id,
            "summary": summary,
            "timestamp": datetime.now().isoformat()
        }
    except Exception as e:
        return {
            "error": str(e),
            "session_id": session_id
        }
api/routers/medical.py
ADDED
@@ -0,0 +1,294 @@

"""
Medical Query Router for RAG AI Advisor
"""
import asyncio
import logging
from fastapi import APIRouter, HTTPException, status
from fastapi.responses import StreamingResponse
import sys
import os
import json

# Add the project root to the path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

from core.agent import safe_run_agent, safe_run_agent_streaming, clear_session_memory, get_active_sessions
from api.models import ChatRequest, ChatResponse, HBVPatientInput, HBVAssessmentResponse
from typing import Optional

logger = logging.getLogger(__name__)
router = APIRouter(tags=["medical"])


def _build_contextual_query(
    query: str,
    patient_context: Optional[HBVPatientInput] = None,
    assessment_result: Optional[HBVAssessmentResponse] = None
) -> str:
    """
    Build an enhanced query that includes patient context and assessment results.

    This helps the agent provide more relevant answers by understanding the specific
    patient case being discussed.

    Args:
        query: The doctor's original question
        patient_context: Optional patient data from assessment
        assessment_result: Optional assessment result with eligibility and recommendations

    Returns:
        Enhanced query string with context
    """
    if not patient_context and not assessment_result:
        # No context, return original query
        return query

    context_parts = [query]

    # Add patient context if available
    if patient_context:
        context_parts.append("\n\n[PATIENT CONTEXT FOR THIS QUESTION]")
        context_parts.append(f"- Age: {patient_context.age}, Sex: {patient_context.sex}")
        context_parts.append(f"- HBsAg: {patient_context.hbsag_status}, HBeAg: {patient_context.hbeag_status}")
        context_parts.append(f"- HBV DNA: {patient_context.hbv_dna_level:,.0f} IU/mL")
        context_parts.append(f"- ALT: {patient_context.alt_level} U/L")
        context_parts.append(f"- Fibrosis: {patient_context.fibrosis_stage}")

        if patient_context.pregnancy_status == "Pregnant":
            context_parts.append(f"- Pregnancy: {patient_context.pregnancy_status}")

        if patient_context.immunosuppression_status and patient_context.immunosuppression_status != "None":
            context_parts.append(f"- Immunosuppression: {patient_context.immunosuppression_status}")

        if patient_context.coinfections:
            context_parts.append(f"- Coinfections: {', '.join(patient_context.coinfections)}")

    # Add assessment result if available
    if assessment_result:
        context_parts.append("\n[PRIOR ASSESSMENT RESULT]")
        context_parts.append(f"- Eligible for treatment: {assessment_result.eligible}")
        # Include a brief summary of the recommendations (first 200 chars)
        rec_summary = assessment_result.recommendations[:200] + "..." if len(assessment_result.recommendations) > 200 else assessment_result.recommendations
        context_parts.append(f"- Assessment summary: {rec_summary}")

    return "\n".join(context_parts)


@router.post("/ask", response_model=ChatResponse)
async def ask(request: ChatRequest):
    """
    Interactive chat endpoint for doctors to ask questions about HBV guidelines.

    This endpoint:
    1. Accepts a doctor's questions about HBV treatment guidelines
    2. Maintains conversation context via session_id
    3. Optionally includes patient context from a prior assessment
    4. Uses the same SASLT 2021 guidelines vector store as /assess
    5. Returns evidence-based answers with guideline citations

    Args:
        request: ChatRequest containing query, session_id, and optional patient/assessment context

    Returns:
        ChatResponse with the AI answer and session_id
    """
    try:
        # Validate input
        if not request.query or not request.query.strip():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query cannot be empty"
            )

        if len(request.query) > 2000:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query is too long. Maximum length is 2000 characters."
            )

        logger.info(f"Processing chat request - Session: {request.session_id}, Query length: {len(request.query)}")

        # Build enhanced query with context if provided
        enhanced_query = _build_contextual_query(
            query=request.query,
            patient_context=request.patient_context,
            assessment_result=request.assessment_result
        )

        # Process through the agent with session context
        response = await safe_run_agent(
            user_input=enhanced_query,
            session_id=request.session_id
        )

        if not response or not response.strip():
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Received empty response from AI agent"
            )

        logger.info(f"Chat request completed - Session: {request.session_id}")

        return ChatResponse(
            response=response,
            session_id=request.session_id
        )

    except HTTPException:
        # Re-raise HTTP exceptions as-is
        raise
    except Exception as e:
        logger.error(f"Error processing chat request: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error processing medical query: {str(e)}"
        )


@router.post("/ask/stream")
async def ask_stream(request: ChatRequest):
    """
    Interactive streaming chat endpoint for doctors to ask questions about HBV guidelines.

    This endpoint:
    1. Streams AI responses in real time for better UX
    2. Accepts a doctor's questions about HBV treatment guidelines
    3. Maintains conversation context via session_id
    4. Optionally includes patient context from a prior assessment
    5. Uses the same SASLT 2021 guidelines vector store as /assess
    6. Returns evidence-based answers with guideline citations

    Args:
        request: ChatRequest containing query, session_id, and optional patient/assessment context

    Returns:
        StreamingResponse with a markdown-formatted AI answer
    """
    # Validate input before starting the stream
    try:
        if not request.query or not request.query.strip():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query cannot be empty"
            )

        if len(request.query) > 2000:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query is too long. Maximum length is 2000 characters."
            )

        logger.info(f"Processing streaming chat request - Session: {request.session_id}, Query length: {len(request.query)}")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Validation error in streaming chat: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid request: {str(e)}"
        )

    async def event_stream():
        try:
            # Build enhanced query with context if provided
            enhanced_query = _build_contextual_query(
                query=request.query,
                patient_context=request.patient_context,
                assessment_result=request.assessment_result
            )

            chunk_buffer = ""
            chunk_count = 0

            async for chunk in safe_run_agent_streaming(
                user_input=enhanced_query,
                session_id=request.session_id
            ):
                chunk_buffer += chunk
                chunk_count += 1

                # Send chunks in reasonable sizes for smoother streaming
                if len(chunk_buffer) >= 10:
                    yield chunk_buffer
                    chunk_buffer = ""
                    await asyncio.sleep(0.01)

            # Send any remaining content
            if chunk_buffer:
                yield chunk_buffer

            logger.info(f"Streaming chat completed - Session: {request.session_id}, Chunks: {chunk_count}")

        except Exception as e:
            error_msg = "\n\n**Error**: An error occurred while processing your request. Please try again or contact support if the issue persists."
            logger.error(f"Error in streaming chat: {str(e)}", exc_info=True)
            yield error_msg

    return StreamingResponse(event_stream(), media_type="text/markdown")


@router.delete("/session/{session_id}")
async def clear_session(session_id: str):
    """
    Clear conversation history for a specific session.

    This is useful when:
    - Starting a new patient case
    - Switching between different patient discussions
    - Resetting the conversation context

    Args:
        session_id: The session identifier to clear

    Returns:
        Success message with session status
    """
    try:
        logger.info(f"Clearing session: {session_id}")

        success = clear_session_memory(session_id)

        if success:
            return {
                "status": "success",
                "message": f"Session '{session_id}' cleared successfully",
                "session_id": session_id
            }
        else:
            return {
                "status": "not_found",
                "message": f"Session '{session_id}' not found or already cleared",
                "session_id": session_id
            }

    except Exception as e:
        logger.error(f"Error clearing session {session_id}: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error clearing session: {str(e)}"
        )


@router.get("/sessions")
async def list_sessions():
    """
    List all active chat sessions.

    Returns:
        List of active session IDs
    """
    try:
        sessions = get_active_sessions()
        return {
            "status": "success",
            "active_sessions": sessions,
            "count": len(sessions)
        }

    except Exception as e:
        logger.error(f"Error listing sessions: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error listing sessions: {str(e)}"
        )
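
Since /ask/stream returns raw markdown text rather than server-sent events, a client can simply read the byte stream as it arrives. A sketch with an invented question and session id; the base URL is again an assumption:

import requests

BASE = "http://127.0.0.1:7860"

payload = {
    "query": "When is tenofovir preferred over entecavir?",  # invented example question
    "session_id": "case-123",
}

with requests.post(f"{BASE}/ask/stream", json=payload, stream=True, timeout=300) as r:
    r.raise_for_status()
    # Chunks are plain text, so they can be printed directly as they stream in.
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
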
api/tempCodeRunnerFile.py
ADDED
@@ -0,0 +1 @@

/assess (POST)
app.py
ADDED
@@ -0,0 +1,23 @@

"""
Startup script for HBV AI Assistant API
"""
import sys
import os
import uvicorn

# Add core to Python path
sys.path.append(os.path.join(os.path.dirname(__file__), 'core'))

if __name__ == "__main__":
    # Get port from environment variable (Hugging Face uses PORT env var)
    port = int(os.environ.get("PORT", 7860))

    uvicorn.run(
        "api.app:app",
        host="0.0.0.0",  # Bind to all interfaces for deployment
        port=port,
        reload=False,  # Disable reload in production for faster startup
        log_level="info",
        access_log=True,
        workers=1  # Single worker for Hugging Face Spaces
    )
core/__init__.py
ADDED
File without changes
core/agent.py
ADDED
@@ -0,0 +1,993 @@
import logging
import traceback
from typing import Any, AsyncGenerator
import asyncio
import requests
import os
import httpx
from langchain.agents import create_openai_tools_agent, AgentExecutor
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import OutputParserException
from langchain.callbacks.base import BaseCallbackHandler
from openai import RateLimitError, APIError

from .config import get_llm, logger
from .tools import (
    medical_guidelines_knowledge_tool,
    get_current_datetime_tool,
)

# LangSmith tracing utilities
from .tracing import traceable, trace, conversation_tracker, log_to_langsmith
from .validation import validate_medical_answer


# ============================================================================
# STREAMING CALLBACK HANDLER
# ============================================================================

class StreamingCallbackHandler(BaseCallbackHandler):
    """Custom callback handler for streaming responses."""

    def __init__(self):
        self.tokens = []
        self.current_response = ""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
        """Called when a new token is generated."""
        self.tokens.append(token)
        self.current_response += token

    def get_response(self) -> str:
        """Get the current response."""
        return self.current_response

    def reset(self):
        """Reset the handler for a new response."""
        self.tokens = []
        self.current_response = ""


# ============================================================================
# CUSTOM EXCEPTION CLASSES
# ============================================================================

class AgentError(Exception):
    """Base exception for agent-related errors."""
    pass


class ToolExecutionError(AgentError):
    """Exception raised when a tool fails to execute."""
    pass


class APIConnectionError(AgentError):
    """Exception raised when API connections fail."""
    pass


class ValidationError(AgentError):
    """Exception raised when input validation fails."""
    pass


# ============================================================================
# AGENT CONFIGURATION
# ============================================================================

# Available tools for the agent
AVAILABLE_TOOLS = [
    medical_guidelines_knowledge_tool,
    get_current_datetime_tool,
]


# System message template for the agent
SYSTEM_MESSAGE = """
You are a specialized HBV Clinical Assistant based on SASLT 2021 guidelines serving hepatologists and infectious disease specialists.

**DUAL ROLE:**
1. **Patient Eligibility Assessment**: Evaluate HBV patients for antiviral treatment eligibility
2. **Clinical Consultation**: Answer questions about HBV management, guidelines, and patient cases

**ASSESSMENT PARAMETERS:**
Evaluate treatment eligibility by analyzing: Serological Status (HBsAg, HBeAg, Anti-HBe), Viral Load (HBV DNA IU/mL), Liver Function (ALT), Fibrosis Stage (F0-F4), Necroinflammatory Activity (A0-A3), Patient Category (immune tolerant/active, inactive carrier, HBeAg-negative CHB), and Special Populations (pregnancy, immunosuppression, coinfections, cirrhosis).

**TREATMENT CRITERIA & OPTIONS:**
- Eligibility: HBV DNA thresholds, ALT elevation (>ULN, >2×ULN), fibrosis stage (≥F2, F3-F4), special populations
- First-line: Entecavir (ETV), Tenofovir Disoproxil Fumarate (TDF), Tenofovir Alafenamide (TAF)
- Alternative agents and PEG-IFN when indicated

**RESPONSE STYLE:**
- Start directly with clinical answers - NO procedural statements ("I will retrieve...", "Let me search...", "Please wait...")
- Use structured, concise clinical format: brief introductory sentence → organized data (tables, bullets, or structured lists) → clinical notes where relevant
- Target 200-400 words (standard queries) or 400-600 words (complex questions)
- Prioritize key information first, use hierarchical formatting (headers, bullets), include specific clinical parameters when relevant
- Use precise medical terminology appropriate for experts
- Answer only what's asked - no tangential information

**STRUCTURED CLINICAL FORMAT (FOLLOW THESE EXAMPLES):**

Example 1 - Tabular data with clinical notes:
"Chronic HBV is classified into five phases using HBsAg/HBeAg status, HBV DNA, ALT, and liver inflammation.

Phase 1 – Chronic HBV infection ("immune tolerant")
HBsAg/HBeAg: High / Positive
HBV DNA: High (>10⁷ IU/mL)
ALT: Normal
Liver inflammation: None/minimal
Notes: More common/prolonged with perinatal/early-life infection; patients are highly contagious
[SASLT HBV Management Guidelines, 2021, p. 3]

Phase 2 – Chronic hepatitis B ("immune reactive HBeAg positive")
HBsAg/HBeAg: High–intermediate / Positive
HBV DNA: Lower (10⁴–10⁷ IU/mL)
ALT: Increased
Liver inflammation: Moderate/severe
Notes: May follow years of immune tolerance; more frequent when infection occurs in adulthood
[SASLT HBV Management Guidelines, 2021, p. 3]"

Example 2 - Bullet lists with citations:
"Screen high-risk populations despite universal childhood vaccination [SASLT 2021, p. 4].

High-risk groups include:
• Expatriate individuals (pre-employment) [SASLT 2021, p. 4]
• Healthcare workers [SASLT 2021, p. 4]
• Household contacts of HBV carriers [SASLT 2021, p. 4]
• Sexual contacts of HBV carriers or those with high-risk sexual behavior [SASLT 2021, p. 4]"

Example 3 - Categorized information:
"Diagnosis of chronic HBV
• HBsAg: detection is the most commonly used test to diagnose chronic HBV infection [SASLT 2021, p. 4].
• HBV disease assessment incorporates HBsAg, HBeAg/anti-HBe, and HBV DNA [SASLT 2021, p. 3].

Identify immunity or prior exposure
• anti-HBs: indicates the patient is protected (immune) against HBV [SASLT 2021, p. 4].
• anti-HBc: indicates previous exposure to HBV [SASLT 2021, p. 4]."

**CLINICAL TONE & NUANCE:**
- Use clinically precise language: "characterized by," "indicates," "reflects," "can be difficult to distinguish"
- Acknowledge clinical uncertainty when present in guidelines: "many fall into a 'grey area'," "requires individualized follow-up," "cannot be captured by a single measurement"
- Include practical guidance: "Practical approach recommended by the guideline," "Bottom line"
- Add clinical context in Notes or footnotes when relevant to interpretation
- Use specific numeric ranges and thresholds exactly as stated in guidelines (e.g., ">10⁷ IU/mL," "10⁴–10⁷ IU/mL," "≥2,000 IU/mL")

**MANDATORY TOOL USAGE:**
ALWAYS use "medical_guidelines_knowledge_tool" FIRST for every medical question - even basic HBV concepts. Do NOT answer from general knowledge. Only formulate responses based on retrieved SASLT 2021 guideline information. All information must come from SASLT 2021 (the only provider available). If no information is found, explicitly state this.

**TOOL USAGE REQUIREMENTS:**
1. **MEDICAL QUESTIONS** (definitions, treatments, guidelines, etc.):
   - MANDATORY: Use "medical_guidelines_knowledge_tool" FIRST
   - Then answer based ONLY on retrieved information

2. **TIME/DATE QUERIES**: For current date/time or references like "today" or "now":
   - MANDATORY: Use "get_current_datetime_tool"

**SEARCH QUERY OPTIMIZATION:**
Transform user questions into comprehensive queries with medical terminology, synonyms, clinical context, AND practical keywords. System uses hybrid search (vector + BM25 keyword matching).

**Key Principles:**
1. **Core Concept**: Start with main medical concept and guideline reference
2. **Add Synonyms**: Include medical term variations
3. **Add Action Verbs**: Include practical keywords from the question (e.g., "screened", "testing", "monitoring", "detection")
4. **Expand Concepts**: Add related clinical terms
5. **Keyword Boosters**: Append domain-specific terms at end for better coverage

**Synonym Mapping:**
- "HBV" + "hepatitis B virus" + "CHB" + "chronic hepatitis B"
- "treatment" + "therapy" + "antiviral" + "management"
- "ALT" + "alanine aminotransferase" + "liver enzymes"
- "fibrosis" + "cirrhosis" + "F2 F3 F4" + "liver fibrosis"
- "HBeAg" + "hepatitis B e antigen" + "HBeAg-positive" + "HBeAg-negative"
- "viral load" + "HBV DNA" + "DNA level" + "viremia"
- "screening" + "screened" + "testing" + "detection" + "diagnosis" + "program"

**Concept Expansion:**
- Treatment → "eligibility criteria indications thresholds when to start"
- Drugs → "first-line second-line alternatives dosing monitoring ETV TDF TAF entecavir tenofovir"
- Assessment → "HBsAg HBeAg anti-HBe HBV DNA ALT fibrosis immune phase"
- Special populations → "pregnancy pregnant women cirrhosis immunosuppression HIV HCV HDV"
- Screening → "target populations high-risk groups screened testing HBsAg detection program"

**Query Construction Formula:**
[Main Concept] + [Guideline Reference] + [Synonyms] + [Action Verbs from Question] + [Related Clinical Terms] + [Keyword Boosters]

**Examples:**
- "Who should be targeted for HBV screening in Saudi Arabia?" → "HBV screening target populations Saudi Arabia SASLT 2021 guidelines screened testing high-risk groups pregnancy HBsAg detection program hepatitis B virus"

- "When to start treatment?" → "HBV treatment initiation criteria indications when to start SASLT 2021 HBV DNA threshold ALT elevation fibrosis stage antiviral therapy eligibility hepatitis B virus management"

- "First-line drugs?" → "first-line antiviral agents HBV treatment SASLT 2021 entecavir ETV tenofovir TDF TAF preferred drugs nucleos(t)ide analogues therapy recommendations hepatitis B virus"

- "HBeAg-negative management?" → "HBeAg-negative chronic hepatitis B CHB management SASLT 2021 treatment criteria HBV DNA threshold ALT elevation anti-HBe immune active phase monitoring hepatitis B e antigen"

**CRITICAL**: Always include practical action verbs from the user's question (e.g., "screened", "tested", "monitored", "detected") as these improve retrieval of relevant guideline sections discussing those specific activities.

**CITATION FORMAT (MANDATORY):**
1. **Inline Citations**: Use format [SASLT 2021, p. X] or [SASLT HBV Management Guidelines, 2021, p. X] after each clinical statement. Cite each page individually - NEVER use ranges.
   - Examples:
     * "HBsAg detection is the most commonly used test [SASLT 2021, p. 4]."
     * "Phase 1 – Chronic HBV infection ("immune tolerant") [SASLT HBV Management Guidelines, 2021, p. 3]"
     * "Treatment criteria include viral load thresholds [SASLT 2021, p. 7], ALT elevation [SASLT 2021, p. 8], and fibrosis assessment [SASLT 2021, p. 9]."

2. **Citation Placement**: Place citation immediately after the relevant statement or at the end of each bullet point/phase description. For structured data (phases, categories), cite after each complete section.

3. **References Section** (Optional): For complex answers, you may end with "**References**" listing all cited pages:
   ```
   **References**
   SASLT 2021 Guidelines - Pages: p. 7, p. 8, p. 9, p. 12, p. 15, p. 18
   (Treatment Eligibility Criteria, First-Line Agents, and Monitoring Protocols)
   ```

4. **Citation Details**: For tables/flowcharts, specify number, title, and relevant rows/columns/values. For text, specify section hierarchy. Include context pages if they contributed to your answer.

**NO GENERAL KNOWLEDGE - GUIDELINE ONLY:**
NEVER answer from general knowledge or speculate. If information is not found in SASLT 2021 after using the tool, respond: "I searched the SASLT 2021 guidelines but could not find specific information about [topic]. You may want to rephrase with more clinical details, consult the guidelines directly, or contact hepatology specialists."

**OUT-OF-SCOPE HANDLING:**
For non-HBV questions (other diseases, non-medical topics), respond professionally: "I'm unable to assist with that request, but I'd be happy to help with HBV-related inquiries."

**PATIENT CONTEXT:**
When a question includes [PATIENT CONTEXT] or [PRIOR ASSESSMENT RESULT], provide personalized case-specific guidance tailored to the patient's parameters (HBV DNA, ALT, fibrosis). Reference prior assessments for consistency.

**ELIGIBILITY ASSESSMENT WORKFLOW:**
1. Retrieve SASLT 2021 criteria via tool
2. Categorize patient phase (immune tolerant/active, inactive carrier, HBeAg-negative CHB, cirrhosis)
3. Compare parameters: HBV DNA vs. threshold, ALT vs. ULN, fibrosis stage, necroinflammatory activity
4. Check special considerations (pregnancy, immunosuppression, coinfections, HCC family history)
5. Determine eligibility (Eligible/Not Eligible/Borderline)
6. Recommend first-line agents (ETV, TDF, TAF) if eligible
7. Outline monitoring plan

**ELIGIBILITY RESPONSE STRUCTURE:**
For patient eligibility: Patient Profile (HBsAg, HBeAg, HBV DNA, ALT, Fibrosis) → SASLT 2021 Criteria (with page citations) → Eligibility Status (Eligible/Not Eligible/Borderline) + Rationale → Treatment Recommendations (ETV, TDF, TAF if eligible) → Monitoring → References

**FORMATTING:**
Use **bold** for critical points/drugs, headers (###) for organization, bullets/numbered lists for sequences, tables for comparisons, blockquotes (>) for direct quotes. Include specific numeric values and thresholds.

**SAFETY:**
For emergencies (acute liver failure, hepatic encephalopathy, severe bleeding, loss of consciousness), respond: "This is an emergency! Call emergency services immediately and seek urgent medical help." Educational information only - not a substitute for clinical judgment. Always respond in English.
"""

# Create the prompt template
prompt_template = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_MESSAGE),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
    MessagesPlaceholder("agent_scratchpad"),
])

# Initialize the agent with lazy loading
def get_agent():
    """Get agent with lazy loading for faster startup"""
    return create_openai_tools_agent(
        llm=get_llm(),
        tools=AVAILABLE_TOOLS,
        prompt=prompt_template,
    )

# Create agent executor with lazy loading
def get_agent_executor():
    """Get agent executor with lazy loading for faster startup"""
    return AgentExecutor(
        agent=get_agent(),
        tools=AVAILABLE_TOOLS,
        verbose=True,
        handle_parsing_errors=True,
        max_iterations=5,
        max_execution_time=90,  # tighten a bit to help responsiveness
    )

# ============================================================================
# SESSION-BASED MEMORY MANAGEMENT
# ============================================================================

class SessionMemoryManager:
    """Manages conversation memory for multiple sessions."""

    def __init__(self):
        self._sessions = {}
        self._default_window_size = 20  # Increased from 10 to maintain better context

    def get_memory(self, session_id: str = "default") -> ConversationBufferWindowMemory:
        """Get or create memory for a specific session."""
        if session_id not in self._sessions:
            import warnings
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=DeprecationWarning)
                self._sessions[session_id] = ConversationBufferWindowMemory(
                    memory_key="chat_history",
                    return_messages=True,
                    k=self._default_window_size  # LangChain's window-size parameter is "k"
                )
        return self._sessions[session_id]

    def clear_session(self, session_id: str) -> bool:
        """Clear memory for a specific session."""
        if session_id in self._sessions:
            self._sessions[session_id].clear()
            del self._sessions[session_id]
            return True
        return False

    def clear_all_sessions(self):
        """Clear all session memories."""
        for memory in self._sessions.values():
            memory.clear()
        self._sessions.clear()

    def get_active_sessions(self) -> list:
        """Get list of active session IDs."""
        return list(self._sessions.keys())

# Global session memory manager
_memory_manager = SessionMemoryManager()


# ============================================================================
# VALIDATION HELPER FUNCTIONS
# ============================================================================

def _should_validate_response(user_input: str, response: str) -> bool:
    """
    Determine if a response should be automatically validated.

    Args:
        user_input: The user's input
        response: The agent's response

    Returns:
        bool: True if the response should be validated
    """
    # Skip validation for certain types of responses
    skip_indicators = [
        "side effect report",
        "adverse drug reaction report",
        "error:",
        "sorry,",
        "i don't know",
        "i do not know",
        "could not find specific information",
        "not found in the retrieved guidelines",
        "validation report",
        "evaluation scores"
    ]

    # Skip validation for side effect reporting queries in user input
    side_effect_input_indicators = [
        "side effect", "adverse reaction", "adverse event", "drug reaction",
        "medication reaction", "patient experienced", "developed after taking",
        "caused by medication", "drug-related", "medication-related"
    ]

    user_input_lower = user_input.lower()
    response_lower = response.lower()

    # Don't validate if user input is about side effect reporting
    if any(indicator in user_input_lower for indicator in side_effect_input_indicators):
        return False

    # Don't validate if response contains skip indicators
    if any(indicator in response_lower for indicator in skip_indicators):
        return False

    # Don't validate very short responses
    if len(response.strip()) < 50:
        return False

    # Validate if response seems to contain medical information
    medical_indicators = [
        "treatment", "therapy", "diagnosis", "medication", "drug", "patient",
        "clinical", "guideline", "recommendation", "according to", "source:",
        "provider:", "page:", "saslt", "hbv", "hepatitis"
    ]

    return any(indicator in response_lower for indicator in medical_indicators)


def _perform_automatic_validation(user_input: str, response: str) -> None:
    """
    Perform automatic validation in the background without displaying results to user.
    Validation results are logged and saved to GitHub repository for backend analysis.

    Args:
        user_input: The user's input
        response: The agent's response

    Returns:
        None: Validation runs silently in background
    """
    try:
        # Import here to avoid circular imports
        from .tools import _last_question, _last_documents

        # Check if we have the necessary context for validation
        if not _last_question or not _last_documents:
            logger.info("Skipping validation: insufficient context")
            return

        # Perform validation using the original user input instead of tool query
        evaluation = validate_medical_answer(user_input, _last_documents, response)

        # Log validation results to backend only (not shown to user)
        report = evaluation.get("validation_report", {})
        logger.info(f"Background validation completed - Interaction ID: {evaluation.get('interaction_id', 'N/A')}")
        logger.info(f"Validation scores - Overall: {report.get('Overall_Rating', 'N/A')}/100, "
                    f"Accuracy: {report.get('Accuracy_Rating', 'N/A')}/100, "
                    f"Coherence: {report.get('Coherence_Rating', 'N/A')}/100, "
                    f"Relevance: {report.get('Relevance_Rating', 'N/A')}/100")

        # Validation is automatically saved to GitHub by validate_medical_answer function
        # No need to return anything - results are stored in backend only

    except Exception as e:
        logger.error(f"Background validation failed: {e}")


# ============================================================================
# STREAMING AGENT FUNCTIONS
# ============================================================================

# @traceable(name="run_agent_streaming")
async def run_agent_streaming(user_input: str, session_id: str = "default", max_retries: int = 3) -> AsyncGenerator[str, None]:
    """
    Run the agent with streaming support and comprehensive error handling.

    This function processes user input through the agent executor with streaming
    capabilities, robust error handling, and automatic retries for recoverable errors.

    Args:
        user_input (str): The user's input message to process
        session_id (str, optional): Session identifier for conversation memory. Defaults to "default".
        max_retries (int, optional): Maximum number of retries for recoverable errors.
            Defaults to 3.

    Yields:
        str: Chunks of the agent's response as they are generated

    Raises:
        None: All exceptions are caught and handled internally
    """
    # Input validation
    if not user_input or not user_input.strip():
        logger.warning("Empty input received")
        yield "Sorry, I didn't receive any questions. Please enter your question or request."
        return

    retry_count = 0
    last_error = None
    current_run_id = None
    # Session metadata (increment conversation count)
    session_metadata = conversation_tracker.get_session_metadata(increment=True)

    while retry_count <= max_retries:
        try:
            # Tracing for streaming disabled to avoid duplicate traces.
            # We keep tracing only for the AgentExecutor in run_agent().
            current_run_id = None
            # Load conversation history from session-specific memory
            memory = _memory_manager.get_memory(session_id)
            chat_history = memory.load_memory_variables({})["chat_history"]

            logger.info(f"Processing user input (attempt {retry_count + 1}): {user_input[:50]}...")

            # Create streaming callback handler
            streaming_handler = StreamingCallbackHandler()

            # Run the agent in a separate thread to avoid blocking
            def run_sync():
                return get_agent_executor().invoke(
                    {
                        "input": user_input.strip(),
                        "chat_history": chat_history,
                    },
                    config={"callbacks": [streaming_handler]},
                )

            # Execute the agent with streaming
            full_response = ""
            previous_length = 0

            # Start the agent execution in background
            loop = asyncio.get_running_loop()  # get_event_loop() is deprecated inside coroutines
            task = loop.run_in_executor(None, run_sync)

            # Stream the response as it's being generated
            while not task.done():
                current_response = streaming_handler.get_response()

                # Yield new tokens if available
                if len(current_response) > previous_length:
                    new_content = current_response[previous_length:]
                    previous_length = len(current_response)
                    yield new_content

                # Small delay to prevent overwhelming the client (faster flushing)
                await asyncio.sleep(0.03)

            # Get the final result
            response = await task

            # Yield any remaining content
            final_response = streaming_handler.get_response()
            if len(final_response) > previous_length:
                yield final_response[previous_length:]

            # If no streaming content was captured, yield the full response
            if not final_response and response and "output" in response:
                full_output = response["output"]
                # Simulate streaming by yielding word by word
                words = full_output.split(' ')
                for word in words:
                    yield word + ' '
                    await asyncio.sleep(0.05)
                final_response = full_output

            # Validate response structure
            if not response or "output" not in response:
                raise ValidationError("Invalid response format from agent")

            if not response["output"] or not response["output"].strip():
                raise ValidationError("Empty response from agent")

            # Perform automatic validation in background (hidden from user)
            base_response = response["output"]
            if _should_validate_response(user_input, base_response):
                logger.info("Performing background validation for streaming response...")
                try:
                    # Run validation silently - results saved to backend/GitHub only
                    _perform_automatic_validation(user_input, base_response)
                except Exception as e:
                    logger.error(f"Background validation failed: {e}")

            # Save conversation context to memory
            memory.save_context(
                {"input": user_input},
                {"output": response["output"]}
            )

            # Log response metrics to LangSmith
            try:
                log_to_langsmith(
                    key="response_metrics",
                    value={
                        "response_length": len(response.get("output", "")),
                        "attempt": retry_count + 1,
                        **session_metadata,
                    },
                    run_id=current_run_id,
                )
            except Exception:
                pass

            logger.info(f"Successfully processed user input: {user_input[:50]}...")
            return

        except RateLimitError as e:
            retry_count += 1
            last_error = e
            wait_time = min(2 ** retry_count, 60)  # Exponential backoff, max 60 seconds

            logger.warning(
                f"Rate limit exceeded. Retrying in {wait_time} seconds... "
                f"(Attempt {retry_count}/{max_retries})"
            )

            if retry_count <= max_retries:
                await asyncio.sleep(wait_time)
                continue
            else:
                logger.error("Rate limit exceeded after maximum retries")
                yield "Sorry, the system is currently busy. Please try again in a little while."
                return

        except (APIError, httpx.RemoteProtocolError, httpx.ReadError, httpx.ConnectError) as e:
            retry_count += 1
            last_error = e
            error_type = type(e).__name__
            logger.error(f"OpenAI API/Connection error ({error_type}): {str(e)}")

            if retry_count <= max_retries:
                wait_time = min(2 ** retry_count, 10)  # Exponential backoff, max 10 seconds
                logger.info(f"Retrying after {wait_time} seconds... (Attempt {retry_count}/{max_retries})")
                await asyncio.sleep(wait_time)
                continue
            else:
                yield "Sorry, there was an error connecting to the service. Please try again later."
                return

        except requests.exceptions.ConnectionError as e:
            retry_count += 1
            last_error = e
            logger.error(f"Network connection error: {str(e)}")

            if retry_count <= max_retries:
                await asyncio.sleep(3)
                continue
            else:
                yield "Sorry, I can't connect to the service right now. Please check your internet connection and try again."
                return

        except requests.exceptions.Timeout as e:
            retry_count += 1
            last_error = e
            logger.error(f"Request timeout: {str(e)}")

            if retry_count <= max_retries:
                await asyncio.sleep(2)
                continue
            else:
                yield "Sorry, the request took longer than expected. Please try again."
                return

        except requests.exceptions.RequestException as e:
            logger.error(f"Request error: {str(e)}")
            yield "Sorry, an error occurred with the request. Please try again."
            return

        except OutputParserException as e:
            logger.error(f"Output parsing error: {str(e)}")
            yield "Sorry, an error occurred while processing the response. Please rephrase your question and try again."
            return

        except ValidationError as e:
            logger.error(f"Validation error: {str(e)}")
            yield "Sorry, an error occurred while validating the data. Please try again."
            return

        except ToolExecutionError as e:
            logger.error(f"Tool execution error: {str(e)}")
            yield "Sorry, an error occurred while executing one of the operations. Please try again or contact technical support."
            return

        except Exception as e:
            logger.error(f"Unexpected error in run_agent_streaming: {str(e)}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            # Log error to LangSmith
            try:
                log_to_langsmith(
                    key="error_log",
                    value={
                        "error": str(e),
                        "error_type": type(e).__name__,
                        **session_metadata,
                    },
                    run_id=current_run_id,
                )
            except Exception:
                pass

            # For unexpected errors, don't retry
            yield "Sorry, an unexpected error occurred. Please try again or contact technical support if the problem persists."
            return

    # This should never be reached, but just in case
    logger.error(f"Maximum retries exceeded. Last error: {str(last_error)}")
    yield "Sorry, I was unable to process your request after several attempts. Please try again later."


async def safe_run_agent_streaming(user_input: str, session_id: str = "default") -> AsyncGenerator[str, None]:
    """
    Streaming wrapper function with additional safety checks and input validation.

    This function provides an additional layer of safety by validating input parameters,
    checking input length constraints, and handling any critical errors that might
    occur during streaming agent execution.

    Args:
        user_input (str): The user's input message to process
        session_id (str, optional): Session identifier for conversation memory. Defaults to "default".

    Yields:
        str: Chunks of the agent's response as they are generated

    Raises:
        None: All exceptions are caught and handled internally
    """
    try:
        # Input type validation
        if not isinstance(user_input, str):
            logger.warning(f"Invalid input type received: {type(user_input)}")
            yield "Sorry, the input must be valid text."
            return

        # Input length validation
        stripped_input = user_input.strip()

        if len(stripped_input) > 1000:
            logger.warning(f"Input too long: {len(stripped_input)} characters")
            yield "Sorry, the message is too long. Please shorten your question."
            return

        if len(stripped_input) == 0:
            logger.warning("Empty input after stripping")
            yield "Sorry, I didn't receive any questions. Please enter your question or request."
            return

        # Stream the response through the main agent function
        async for chunk in run_agent_streaming(user_input, session_id):
            yield chunk

    except Exception as e:
        logger.critical(f"Critical error in safe_run_agent_streaming: {str(e)}")
        logger.critical(f"Traceback: {traceback.format_exc()}")
        yield "Sorry, a critical system error occurred. Please contact technical support immediately."


@traceable(name="run_agent")
async def run_agent(user_input: str, session_id: str = "default", max_retries: int = 3) -> str:
    """
    Run the agent with comprehensive error handling and retry logic.

    This function processes user input through the agent executor with robust
    error handling, automatic retries for recoverable errors, and comprehensive
    logging for debugging and monitoring.

    Args:
        user_input (str): The user's input message to process
        session_id (str, optional): Session identifier for conversation memory. Defaults to "default".
        max_retries (int, optional): Maximum number of retries for recoverable errors.
            Defaults to 3.

    Returns:
        str: The agent's response or an appropriate error message in English

    Raises:
        None: All exceptions are caught and handled internally
    """
    # Input validation
    if not user_input or not user_input.strip():
        logger.warning("Empty input received")
        return "Sorry, I didn't receive any questions. Please enter your question or request."

    retry_count = 0
    last_error = None
    current_run_id = None
    session_metadata = conversation_tracker.get_session_metadata(increment=True)

    while retry_count <= max_retries:
        try:
            # Load conversation history from session-specific memory
            memory = _memory_manager.get_memory(session_id)
            chat_history = memory.load_memory_variables({})["chat_history"]

            logger.info(f"Processing user input (attempt {retry_count + 1}): {user_input[:50]}...")

            # Invoke the agent with input and history (synchronous call)
            response = get_agent_executor().invoke({
                "input": user_input.strip(),
                "chat_history": chat_history
            })
            current_run_id = None  # This will be handled by LangChain's tracer

            # Validate response structure
            if not response or "output" not in response or not isinstance(response["output"], str):
                raise ValidationError("Invalid response format from agent")

            if not response["output"] or not response["output"].strip():
                raise ValidationError("Empty response from agent")

            # Save conversation context to memory
            memory.save_context(
                {"input": user_input},
                {"output": response["output"]}
            )

            # Log response metrics
            try:
                log_to_langsmith(
                    key="response_metrics",
                    value={
                        "response_length": len(response.get("output", "")),
                        "attempt": retry_count + 1,
                        **session_metadata,
                    },
                    run_id=current_run_id,
                )
            except Exception:
                pass

            logger.info(f"Successfully processed user input: {user_input[:50]}...")

            # Perform automatic validation in background (hidden from user)
            final_response = response["output"]
            if _should_validate_response(user_input, final_response):
                logger.info("Performing background validation...")
                try:
                    # Run validation silently - results saved to backend/GitHub only
                    _perform_automatic_validation(user_input, final_response)
                except Exception as e:
                    logger.error(f"Background validation failed: {e}")

            return final_response

        except RateLimitError as e:
            retry_count += 1
            last_error = e
            wait_time = min(2 ** retry_count, 60)  # Exponential backoff, max 60 seconds

            logger.warning(
                f"Rate limit exceeded. Retrying in {wait_time} seconds... "
                f"(Attempt {retry_count}/{max_retries})"
            )

            if retry_count <= max_retries:
                await asyncio.sleep(wait_time)
                continue
            else:
                logger.error("Rate limit exceeded after maximum retries")
                return "Sorry, the system is currently busy. Please try again in a little while."

        except APIError as e:
            retry_count += 1
            last_error = e
            logger.error(f"OpenAI API error: {str(e)}")

            if retry_count <= max_retries:
                await asyncio.sleep(2)
                continue
            else:
                return "Sorry, there was an error connecting to the service. Please try again later."

        except requests.exceptions.ConnectionError as e:
            retry_count += 1
            last_error = e
            logger.error(f"Network connection error: {str(e)}")

            if retry_count <= max_retries:
                await asyncio.sleep(3)
                continue
            else:
                return "Sorry, I can't connect to the service right now. Please check your internet connection and try again."

        except requests.exceptions.Timeout as e:
            retry_count += 1
            last_error = e
            logger.error(f"Request timeout: {str(e)}")

            if retry_count <= max_retries:
                await asyncio.sleep(2)
                continue
            else:
                return "Sorry, the request took longer than expected. Please try again."

        except requests.exceptions.RequestException as e:
            logger.error(f"Request error: {str(e)}")
            return "Sorry, an error occurred with the request. Please try again."

        except OutputParserException as e:
            logger.error(f"Output parsing error: {str(e)}")
            return "Sorry, an error occurred while processing the response. Please rephrase your question and try again."

        except ValidationError as e:
            logger.error(f"Validation error: {str(e)}")
            return "Sorry, an error occurred while validating the data. Please try again."

        except ToolExecutionError as e:
            logger.error(f"Tool execution error: {str(e)}")
            return "Sorry, an error occurred while executing one of the operations. Please try again or contact technical support."

        except Exception as e:
            logger.error(f"Unexpected error in run_agent: {str(e)}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            # Log error
            try:
                log_to_langsmith(
                    key="error_log",
                    value={
                        "error": str(e),
                        "error_type": type(e).__name__,
                        **session_metadata,
                    },
                    run_id=current_run_id,
                )
            except Exception:
                pass

            # For unexpected errors, don't retry
            return "Sorry, an unexpected error occurred. Please try again or contact technical support if the problem persists."

    # This should never be reached, but just in case
    logger.error(f"Maximum retries exceeded. Last error: {str(last_error)}")
    return "Sorry, I was unable to process your request after several attempts. Please try again later."


async def safe_run_agent(user_input: str, session_id: str = "default") -> str:
    """
    Wrapper function for run_agent with additional safety checks and input validation.

    This function provides an additional layer of safety by validating input parameters,
    checking input length constraints, and handling any critical errors that might
    occur during agent execution.

    Args:
        user_input (str): The user's input message to process
        session_id (str, optional): Session identifier for conversation memory. Defaults to "default".

    Returns:
        str: The agent's response or an appropriate error message in English

    Raises:
        None: All exceptions are caught and handled internally
    """
    try:
        # Input type validation
        if not isinstance(user_input, str):
            logger.warning(f"Invalid input type received: {type(user_input)}")
            return "Sorry, the input must be valid text."

        # Input length validation
        stripped_input = user_input.strip()

        # if len(stripped_input) > 1000:
        #     logger.warning(f"Input too long: {len(stripped_input)} characters")
        #     return "Sorry, the message is too long. Please shorten your question."

        if len(stripped_input) == 0:
            logger.warning("Empty input after stripping")
            return "Sorry, I didn't receive any questions. Please enter your question or request."

        # Process the input through the main agent function
        return await run_agent(user_input, session_id)

    except Exception as e:
        logger.critical(f"Critical error in safe_run_agent: {str(e)}")
        logger.critical(f"Traceback: {traceback.format_exc()}")
        return "Sorry, a critical system error occurred. Please contact technical support immediately."


def clear_memory() -> None:
    """
    Clear the conversation memory.

    This function clears all stored conversation history from memory,
    effectively starting a fresh conversation session.
    """
    try:
        _memory_manager.clear_all_sessions()
        logger.info("Conversation memory cleared successfully")
    except Exception as e:
        logger.error(f"Error clearing memory: {str(e)}")


def get_memory_summary(session_id: str = "default") -> str:
    """
    Get a summary of the conversation history for a specific session.

    Args:
        session_id (str, optional): Session identifier. Defaults to "default".

    Returns:
        str: A summary of the conversation history stored in memory
    """
    try:
        memory = _memory_manager.get_memory(session_id)
        memory_vars = memory.load_memory_variables({})
        return str(memory_vars.get("chat_history", "No conversation history available"))
    except Exception as e:
        logger.error(f"Error getting memory summary: {str(e)}")
        return "Error retrieving conversation history"


def clear_session_memory(session_id: str) -> bool:
    """
    Clear conversation memory for a specific session.

    Args:
        session_id (str): Session identifier to clear

    Returns:
        bool: True if session was cleared, False if session didn't exist
    """
    return _memory_manager.clear_session(session_id)


def get_active_sessions() -> list:
    """
    Get list of all active session IDs.

    Returns:
        list: List of active session identifiers
    """
    return _memory_manager.get_active_sessions()
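For reference, the public surface of this module can be driven as in the sketch below, using only the functions defined above (the question text and session id are illustrative):

```python
import asyncio

from core.agent import (
    safe_run_agent,
    safe_run_agent_streaming,
    clear_session_memory,
)

async def main() -> None:
    # One-shot call: returns the full answer as a single string.
    answer = await safe_run_agent(
        "When should antiviral therapy be started?", session_id="demo"
    )
    print(answer)

    # Streaming call: chunks are yielded as the agent generates them.
    async for chunk in safe_run_agent_streaming("First-line drugs?", session_id="demo"):
        print(chunk, end="", flush=True)

    # Drop the conversation history for this session when done.
    clear_session_memory("demo")

asyncio.run(main())
```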
core/background_init.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Background initialization system for preloading heavy components during app startup.
|
| 3 |
+
This module handles the eager loading of embedding models, retrievers, and chunks
|
| 4 |
+
to improve first-question response time.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import threading
|
| 8 |
+
import time
|
| 9 |
+
from typing import Optional, Callable
|
| 10 |
+
from .config import logger
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BackgroundInitializer:
|
| 14 |
+
"""Manages background initialization of heavy components"""
|
| 15 |
+
|
| 16 |
+
def __init__(self):
|
| 17 |
        self._initialization_thread: Optional[threading.Thread] = None
        self._initialization_complete = threading.Event()
        self._initialization_error: Optional[Exception] = None
        self._progress_callback: Optional[Callable[[str, int], None]] = None
        self._status = "Not started"

    def set_progress_callback(self, callback: Callable[[str, int], None]):
        """Set a callback function to receive progress updates"""
        self._progress_callback = callback

    def _update_progress(self, message: str, percentage: int):
        """Update progress and call callback if set"""
        self._status = message
        logger.info(f"🔄 Background Init: {message} ({percentage}%)")
        if self._progress_callback:
            try:
                self._progress_callback(message, percentage)
            except Exception as e:
                logger.error(f"Progress callback error: {e}")

    def _initialize_components(self):
        """Initialize all heavy components in background thread"""
        try:
            self._update_progress("Starting background initialization...", 0)

            # Step 1: Load embedding model (this is the heaviest component)
            self._update_progress("Loading embedding model...", 10)
            from .config import get_embedding_model
            embedding_model = get_embedding_model()
            self._update_progress("Embedding model loaded successfully", 40)

            # Step 2: Initialize retrievers (this will load chunks and create vector store)
            self._update_progress("Initializing retrievers and loading chunks...", 50)
            from .retrievers import _ensure_initialized
            _ensure_initialized()
            self._update_progress("Retrievers initialized successfully", 90)

            # Step 3: Learn medical terminology from corpus
            self._update_progress("Learning medical terminology from corpus...", 92)
            try:
                from .medical_terminology import learn_from_corpus
                from . import utils

                # Load chunks to learn from
                chunks = utils.load_chunks()
                if chunks:
                    # Convert to format expected by learner
                    documents = [{'content': chunk.page_content} for chunk in chunks[:1000]]  # Limit for performance
                    learn_from_corpus(documents)
                    logger.info(f"Learned medical terminology from {len(documents)} documents")
            except Exception as e:
                logger.warning(f"Could not learn terminology from corpus: {e}")

            # Step 4: Warm up LLM (optional, lightweight)
            self._update_progress("Warming up LLM...", 97)
            from .config import get_llm
            llm = get_llm()
            self._update_progress("All components initialized successfully", 100)

            logger.info("✅ Background initialization completed successfully")

        except Exception as e:
            self._initialization_error = e
            logger.error(f"❌ Background initialization failed: {e}")
            self._update_progress(f"Initialization failed: {str(e)}", -1)
        finally:
            self._initialization_complete.set()

    def start_background_initialization(self):
        """Start background initialization in a separate thread"""
        if self._initialization_thread is not None:
            logger.warning("Background initialization already started")
            return

        logger.info("🚀 Starting background initialization...")
        self._initialization_thread = threading.Thread(
            target=self._initialize_components,
            name="BackgroundInitializer",
            daemon=True
        )
        self._initialization_thread.start()

    def is_complete(self) -> bool:
        """Check if initialization is complete"""
        return self._initialization_complete.is_set()

    def wait_for_completion(self, timeout: Optional[float] = None) -> bool:
        """Wait for initialization to complete"""
        return self._initialization_complete.wait(timeout)

    def get_status(self) -> str:
        """Get current initialization status"""
        return self._status

    def get_error(self) -> Optional[Exception]:
        """Get initialization error if any"""
        return self._initialization_error

    def is_successful(self) -> bool:
        """Check if initialization completed successfully"""
        return self.is_complete() and self._initialization_error is None


# Global initializer instance
_background_initializer = BackgroundInitializer()


def start_background_initialization(progress_callback: Optional[Callable[[str, int], None]] = None):
    """Start background initialization with optional progress callback"""
    if progress_callback:
        _background_initializer.set_progress_callback(progress_callback)
    _background_initializer.start_background_initialization()


def wait_for_initialization(timeout: Optional[float] = None) -> bool:
    """Wait for background initialization to complete"""
    return _background_initializer.wait_for_completion(timeout)


def is_initialization_complete() -> bool:
    """Check if background initialization is complete"""
    return _background_initializer.is_complete()


def get_initialization_status() -> str:
    """Get current initialization status"""
    return _background_initializer.get_status()


def get_initialization_error() -> Optional[Exception]:
    """Get initialization error if any"""
    return _background_initializer.get_error()


def is_initialization_successful() -> bool:
    """Check if initialization completed successfully"""
    return _background_initializer.is_successful()
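A minimal usage sketch of this module's public API, for orientation only: the callback body, the 300-second timeout, and the print statements are illustrative assumptions, not part of the committed code.

# Hypothetical usage sketch (not part of this commit):
from core.background_init import (
    start_background_initialization,
    wait_for_initialization,
    is_initialization_successful,
    get_initialization_status,
)

def on_progress(message: str, percentage: int):
    # Receives the same (message, percentage) pairs that _update_progress emits
    print(f"[init] {message} ({percentage}%)")

start_background_initialization(progress_callback=on_progress)
# ... serve lightweight endpoints immediately; block only where components are required ...
if wait_for_initialization(timeout=300) and is_initialization_successful():
    print("components ready")
else:
    print(f"init not ready: {get_initialization_status()}")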
core/config.py
ADDED
@@ -0,0 +1,152 @@
import os
from pathlib import Path
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
import logging
from logging.handlers import RotatingFileHandler
from pydantic_settings import BaseSettings, SettingsConfigDict

# Initialize environment
load_dotenv()


# --- Settings (simple, in-file) ---
class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', extra='ignore')

    OPENAI_API_KEY: str
    OPENAI_BASE_URL: str | None = None

    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
    DATA_DIR: str = os.getenv("DATA_DIR", "")
    LOG_DIR: str = os.getenv("LOG_DIR", "")


settings = Settings()


# --- File Path Configuration (Cross-platform compatible) ---
PROJECT_ROOT = Path(__file__).parent.parent.absolute()
DATA_DIR = Path(settings.DATA_DIR or (PROJECT_ROOT / "data"))
NEW_DATA = DATA_DIR / "new_data"
PROCESSED_DATA = DATA_DIR / "processed_data"
CHUNKS_PATH = DATA_DIR / "chunks.pkl"
VECTOR_STORE_DIR = DATA_DIR / "vector_store"
DATA_DIR.mkdir(parents=True, exist_ok=True)
NEW_DATA.mkdir(parents=True, exist_ok=True)
PROCESSED_DATA.mkdir(parents=True, exist_ok=True)
VECTOR_STORE_DIR.mkdir(parents=True, exist_ok=True)

# Setup logging
LOG_DIR = Path(settings.LOG_DIR or (Path(__file__).parent.parent / "logs"))
LOG_DIR.mkdir(parents=True, exist_ok=True)

LOG_FILE = LOG_DIR / "app.log"

# Configure application logger (avoid duplicate handlers)
LOG_LEVEL = settings.LOG_LEVEL.upper()
logger = logging.getLogger("AgenticMedicalRAG")  # centralized logger
logger.setLevel(LOG_LEVEL)
logger.propagate = False
if not logger.handlers:
    formatter = logging.Formatter(
        fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler = RotatingFileHandler(
        LOG_FILE,
        maxBytes=1000000,
        backupCount=3,
        encoding="utf-8"
    )
    file_handler.setFormatter(formatter)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)


# --- LLM Configuration with lazy loading ---
_llm = None

def get_llm():
    """Get LLM with lazy loading for faster startup"""
    global _llm
    if _llm is None:
        logger.info("Initializing LLM (first time)...")
        openai_key = settings.OPENAI_API_KEY

        if not openai_key:
            logger.error("OPENAI_API_KEY not found in environment variables")
            raise ValueError("OpenAI API key is required. Please set OPENAI_API_KEY environment variable.")

        try:
            _llm = ChatOpenAI(
                model="gpt-4o",
                api_key=openai_key,
                base_url=settings.OPENAI_BASE_URL,
                temperature=0.0,
                max_tokens=2048,
                request_timeout=30,  # Increased timeout for stability
                max_retries=2,
                streaming=True,
            )
            logger.info("LLM initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize LLM: {e}")
            raise
    return _llm

def create_llm():
    """Create LLM with proper error handling and fallbacks"""
    return get_llm()

# Lazy loading - only initialize when actually needed
LLM = None  # Will be loaded on first use

# --- Embedding Model Configuration with lazy loading ---
_embedding_model = None

def get_embedding_model():
    """Get embedding model with lazy loading for faster startup"""
    global _embedding_model
    if _embedding_model is None:
        logger.info("Loading embedding model (first time)...")
        try:
            _embedding_model = HuggingFaceEmbeddings(
                model_name="abhinand/MedEmbed-base-v0.1",
                model_kwargs={'device': 'cpu'},
                encode_kwargs={'normalize_embeddings': True}
            )
            logger.info("Embedding model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            raise ValueError("Failed to load embedding model")
    return _embedding_model

# For backward compatibility
def create_embedding_model():
    """Create embedding model with proper error handling"""
    return get_embedding_model()

# Lazy loading - only load when actually needed
EMBEDDING_MODEL = None  # Will be loaded on first use

# Configuration validation
def validate_config():
    """Validate all required configurations"""
    required_env_vars = ["OPENAI_API_KEY"]
    missing_vars = [var for var in required_env_vars if not getattr(settings, var, None)]

    if missing_vars:
        raise ValueError(f"Missing required environment variables: {missing_vars}")

    logger.info("Configuration validation completed")

# Run validation on import
try:
    validate_config()
except Exception as e:
    logger.error(f"Configuration validation failed: {e}")
    raise e
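A short sketch of the lazy-loading contract this module exposes: repeated calls return the same cached instances. The assert and print lines are illustrative; note that merely importing core.config runs validate_config(), so OPENAI_API_KEY must already be set.

# Hypothetical usage sketch (not part of this commit):
from core.config import get_llm, get_embedding_model, logger

llm = get_llm()                     # built once, then cached in module-level _llm
assert llm is get_llm()             # subsequent calls reuse the same instance
embeddings = get_embedding_model()  # loads MedEmbed-base-v0.1 on first call only
logger.info("LLM and embedding model ready")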
core/context_enrichment.py
ADDED
@@ -0,0 +1,342 @@
"""
Context Enrichment Module for Medical RAG

This module enriches retrieved documents with surrounding context (adjacent pages)
to provide comprehensive information for expert medical professionals.
"""

from typing import List, Dict, Set, Optional
from langchain.schema import Document
from pathlib import Path
from .config import logger


class ContextEnricher:
    """
    Enriches retrieved documents with surrounding pages for richer context.
    """

    def __init__(self, cache_size: int = 100):
        """
        Initialize context enricher with document cache.

        Args:
            cache_size: Maximum number of source documents to cache
        """
        self._document_cache: Dict[str, List[Document]] = {}
        self._cache_size = cache_size
        self._all_chunks_cache: Optional[List[Document]] = None  # Cache all chunks to avoid reloading

    def enrich_documents(
        self,
        retrieved_docs: List[Document],
        pages_before: int = 1,
        pages_after: int = 1,
        max_enriched_docs: int = 5
    ) -> List[Document]:
        """
        Enrich retrieved documents by adding separate context pages.

        Args:
            retrieved_docs: List of retrieved documents
            pages_before: Number of pages to include before each document
            pages_after: Number of pages to include after each document
            max_enriched_docs: Maximum number of documents to enrich (top results)

        Returns:
            List with original documents + separate context page documents
        """
        if not retrieved_docs:
            return []

        result_docs = []
        processed_sources = set()
        enriched_count = 0

        # Only enrich top documents to avoid overwhelming context
        docs_to_enrich = retrieved_docs[:max_enriched_docs]

        for doc in docs_to_enrich:
            try:
                # Get source information
                source = doc.metadata.get('source', 'unknown')
                page_num = doc.metadata.get('page_number', 1)

                # Skip if this source-page combination was already processed
                source_page_key = f"{source}_{page_num}"
                if source_page_key in processed_sources:
                    continue

                processed_sources.add(source_page_key)

                # Get surrounding pages
                surrounding_docs = self._get_surrounding_pages(
                    doc,
                    pages_before,
                    pages_after
                )

                if surrounding_docs:
                    # Add separate documents for each page
                    page_docs = self._create_separate_page_documents(
                        doc,
                        surrounding_docs,
                        pages_before,
                        pages_after
                    )
                    result_docs.extend(page_docs)
                    enriched_count += 1

                    # Log enrichment details
                    page_numbers = [int(d.metadata.get('page_number', 0)) for d in page_docs]
                    logger.debug(f"Enriched {source} page {page_num} with pages: {page_numbers}")
                else:
                    # No surrounding pages found, add original with empty enrichment metadata
                    original_with_metadata = self._add_empty_enrichment_metadata(doc)
                    result_docs.append(original_with_metadata)

            except Exception as e:
                logger.warning(f"Could not enrich document from {doc.metadata.get('source')}: {e}")
                original_with_metadata = self._add_empty_enrichment_metadata(doc)
                result_docs.append(original_with_metadata)

        # Add remaining documents without enrichment
        for doc in retrieved_docs[max_enriched_docs:]:
            original_with_metadata = self._add_empty_enrichment_metadata(doc)
            result_docs.append(original_with_metadata)

        logger.info(f"Enriched {enriched_count} documents with surrounding context pages")
        return result_docs

    def _get_surrounding_pages(
        self,
        doc: Document,
        pages_before: int,
        pages_after: int
    ) -> List[Document]:
        """
        Get surrounding pages for a document.

        Args:
            doc: Original document
            pages_before: Number of pages before
            pages_after: Number of pages after

        Returns:
            List of surrounding documents (including original), deduplicated by page number
        """
        source = doc.metadata.get('source', 'unknown')
        page_num = doc.metadata.get('page_number', 1)
        provider = doc.metadata.get('provider', 'unknown')
        disease = doc.metadata.get('disease', 'unknown')

        # Try to get full document from cache or load it
        full_doc_pages = self._get_full_document(source, provider, disease)

        if not full_doc_pages:
            return []

        # Find the target page and surrounding pages
        target_page = int(page_num) if isinstance(page_num, (int, str)) else 1

        # Use a dict to deduplicate by page number (keep first occurrence)
        pages_dict = {}

        for page_doc in full_doc_pages:
            doc_page_num = page_doc.metadata.get('page_number', 0)
            if isinstance(doc_page_num, str):
                try:
                    doc_page_num = int(doc_page_num)
                except ValueError:
                    continue

            # Include pages within range
            if target_page - pages_before <= doc_page_num <= target_page + pages_after:
                # Only add if not already present (deduplication)
                if doc_page_num not in pages_dict:
                    pages_dict[doc_page_num] = page_doc

        # Return sorted by page number
        surrounding = [pages_dict[pn] for pn in sorted(pages_dict.keys())]

        return surrounding

    def _get_full_document(
        self,
        source: str,
        provider: str,
        disease: str
    ) -> Optional[List[Document]]:
        """
        Get full document pages from chunks cache.

        Args:
            source: Source filename
            provider: Provider name
            disease: Disease name

        Returns:
            List of all pages in the document, or None if not found
        """
        cache_key = f"{provider}_{disease}_{source}"

        # Check cache
        if cache_key in self._document_cache:
            return self._document_cache[cache_key]

        # Load from chunks cache instead of trying to reload PDFs
        try:
            from . import utils

            # Load all chunks (use cached version to avoid redundant loading)
            if self._all_chunks_cache is None:
                self._all_chunks_cache = utils.load_chunks()
                if self._all_chunks_cache:
                    logger.debug(f"Loaded {len(self._all_chunks_cache)} chunks into enricher cache")

            all_chunks = self._all_chunks_cache
            if not all_chunks:
                logger.debug("No chunks available for enrichment")
                return None

            # Filter chunks for this specific document
            doc_pages = []
            for chunk in all_chunks:
                chunk_source = chunk.metadata.get('source', '')
                chunk_provider = chunk.metadata.get('provider', '')
                chunk_disease = chunk.metadata.get('disease', '')

                # Match by source, provider, and disease
                if (chunk_source == source and
                        chunk_provider == provider and
                        chunk_disease == disease):
                    doc_pages.append(chunk)

            if not doc_pages:
                logger.debug(f"Could not find chunks for document: {source} (Provider: {provider}, Disease: {disease})")
                return None

            # Sort by page number
            doc_pages.sort(key=lambda d: int(d.metadata.get('page_number', 0)))

            # Cache it (with size limit)
            if len(self._document_cache) >= self._cache_size:
                # Remove oldest entry
                self._document_cache.pop(next(iter(self._document_cache)))

            self._document_cache[cache_key] = doc_pages
            logger.debug(f"Loaded {len(doc_pages)} pages for {source} from chunks cache")
            return doc_pages

        except Exception as e:
            logger.warning(f"Error loading document from chunks cache {source}: {e}")
            return None

    def _create_separate_page_documents(
        self,
        original_doc: Document,
        surrounding_docs: List[Document],
        pages_before: int,
        pages_after: int
    ) -> List[Document]:
        """
        Create separate document objects for original page and context pages.

        Args:
            original_doc: Original retrieved document
            surrounding_docs: List of surrounding documents
            pages_before: Number of pages before
            pages_after: Number of pages after

        Returns:
            List of separate documents (context pages + original page + context pages)
        """
        # Sort by page number
        sorted_docs = sorted(
            surrounding_docs,
            key=lambda d: int(d.metadata.get('page_number', 0))
        )

        original_page = int(original_doc.metadata.get('page_number', 1))
        result_docs = []

        for doc in sorted_docs:
            page_num = int(doc.metadata.get('page_number', 0))

            # Determine if this is a context page or the original page
            is_context_page = (page_num != original_page)

            # Create document with appropriate metadata
            page_doc = Document(
                page_content=doc.page_content,
                metadata={
                    **doc.metadata,
                    'context_enrichment': is_context_page,
                    'enriched': False,
                    'pages_included': [],
                    'primary_page': None,
                    'context_pages_before': None,
                    'context_pages_after': None,
                }
            )

            result_docs.append(page_doc)

        return result_docs

    def _add_empty_enrichment_metadata(self, doc: Document) -> Document:
        """
        Add empty enrichment metadata fields to a document.

        Args:
            doc: Original document

        Returns:
            Document with enrichment metadata fields set to default values
        """
        return Document(
            page_content=doc.page_content,
            metadata={
                **doc.metadata,
                'enriched': False,
                'pages_included': [],
                'primary_page': None,
                'context_pages_before': None,
                'context_pages_after': None,
            }
        )


# Global enricher instance
_context_enricher = ContextEnricher(cache_size=100)


def enrich_retrieved_documents(
    documents: List[Document],
    pages_before: int = 1,
    pages_after: int = 1,
    max_enriched: int = 5
) -> List[Document]:
    """
    Convenience function to enrich retrieved documents.

    Args:
        documents: Retrieved documents
        pages_before: Number of pages to include before each document
        pages_after: Number of pages to include after each document
        max_enriched: Maximum number of documents to enrich

    Returns:
        Enriched documents with surrounding context
    """
    return _context_enricher.enrich_documents(
        documents,
        pages_before=pages_before,
        pages_after=pages_after,
        max_enriched_docs=max_enriched
    )


def get_context_enricher() -> ContextEnricher:
    """Get the global context enricher instance."""
    return _context_enricher
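A minimal sketch of the enrichment flow, assuming the chunks cache loaded by utils.load_chunks() is available and the input document carries the source/provider/disease/page_number metadata set by the loaders; the document content and metadata values are invented.

# Hypothetical usage sketch (not part of this commit):
from langchain.schema import Document
from core.context_enrichment import enrich_retrieved_documents

hits = [Document(
    page_content="Treatment thresholds ...",
    metadata={"source": "SASLT_2021.pdf", "provider": "SASLT",
              "disease": "HBV", "page_number": 12},
)]
enriched = enrich_retrieved_documents(hits, pages_before=1, pages_after=1)
for d in enriched:
    # Context pages are flagged via 'context_enrichment' by _create_separate_page_documents
    print(d.metadata.get("page_number"), d.metadata.get("context_enrichment"))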
core/data_loaders.py
ADDED
@@ -0,0 +1,142 @@
# Import required libraries
import pandas as pd
from pathlib import Path
from typing import List
from langchain.schema import Document
from .config import logger
from langchain_pymupdf4llm import PyMuPDF4LLMLoader
from langchain_community.document_loaders.parsers import TesseractBlobParser


def load_pdf_documents(pdf_path: Path) -> List[Document]:
    """
    Load and process PDF documents from medical guidelines using PyMuPDF4LLMLoader.
    Uses Tesseract for image extraction and optimized table extraction for medical documents.
    Extracts disease and provider from directory structure.

    Directory structure expected: data/new_data/PROVIDER/file.pdf
    Example: data/new_data/SASLT/SASLT_2021.pdf

    Args:
        pdf_path: Path to the PDF file

    Returns:
        List of Document objects with metadata (source, disease, provider, page_number)
    """
    try:
        # Validate file exists
        if not pdf_path.exists():
            raise FileNotFoundError(f"PDF file not found at {pdf_path}")

        # Extract provider from directory structure
        # Structure: data/new_data/PROVIDER/file.pdf
        path_parts = pdf_path.parts
        disease = "HBV"  # Default disease for this system
        provider = "unknown"

        # Find provider: it's the parent directory of the PDF file
        if len(path_parts) >= 2:
            provider = path_parts[-2]  # Parent directory (e.g., SASLT)

        # If provider is 'new_data', it means file is directly in new_data folder
        if provider.lower() == "new_data":
            provider = "unknown"

        # Initialize PyMuPDF4LLMLoader
        loader = PyMuPDF4LLMLoader(
            str(pdf_path),
            mode="page",
            extract_images=True,
            images_parser=TesseractBlobParser(),
            table_strategy="lines"
        )

        raw_documents = loader.load()

        documents = []
        for idx, doc in enumerate(raw_documents):
            if doc.page_content.strip():
                # Extract actual page number from metadata, default to sequential numbering
                # PyMuPDF4LLMLoader uses 0-indexed pages, so we add 1 for human-readable page numbers
                actual_page = doc.metadata.get("page")
                if actual_page is not None:
                    # If page is 0-indexed, add 1 to make it 1-indexed
                    page_num = actual_page + 1 if actual_page == idx else actual_page
                else:
                    # Fallback to 1-indexed sequential numbering
                    page_num = idx + 1

                processed_doc = Document(
                    page_content=doc.page_content,
                    metadata={
                        "source": pdf_path.name,
                        "disease": disease,
                        "provider": provider,
                        "page_number": page_num
                    }
                )
                documents.append(processed_doc)

        logger.info(f"Loaded {len(documents)} pages from PDF: {pdf_path.name} (Disease: {disease}, Provider: {provider})")
        return documents

    except Exception as e:
        logger.error(f"Error loading PDF documents from {pdf_path}: {str(e)}")
        raise


def load_markdown_documents(md_path: Path) -> List[Document]:
    """
    Load and process Markdown medical guidelines.
    Extracts disease and provider from directory structure.

    Directory structure expected: data/new_data/PROVIDER/file.md
    Example: data/new_data/SASLT/guidelines.md

    Args:
        md_path: Path to the Markdown file

    Returns:
        List of Document objects with metadata (source, disease, provider, page_number)
    """
    try:
        # Validate file exists
        if not md_path.exists():
            raise FileNotFoundError(f"Markdown file not found at {md_path}")

        # Extract provider from directory structure
        # Structure: data/new_data/PROVIDER/file.md
        path_parts = md_path.parts
        disease = "HBV"  # Default disease for this system
        provider = "unknown"

        # Find provider: it's the parent directory of the markdown file
        if len(path_parts) >= 2:
            provider = path_parts[-2]  # Parent directory (e.g., SASLT)

        # If provider is 'new_data', it means file is directly in new_data folder
        if provider.lower() == "new_data":
            provider = "unknown"

        # Read markdown content
        with open(md_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Create document with minimal metadata for RAG
        doc = Document(
            page_content=content,
            metadata={
                "source": md_path.name,
                "disease": disease,
                "provider": provider,
                "page_number": 1
            }
        )

        logger.info(f"Loaded Markdown document: {md_path.name} (Disease: {disease}, Provider: {provider})")
        return [doc]

    except Exception as e:
        logger.error(f"Error loading Markdown document from {md_path}: {str(e)}")
        raise
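A sketch of a loader call against the directory layout the docstrings document; the concrete path matches the example given above but is otherwise illustrative.

# Hypothetical usage sketch (not part of this commit):
from pathlib import Path
from core.data_loaders import load_pdf_documents

# Expected layout: data/new_data/PROVIDER/file.pdf
pdf = Path("data/new_data/SASLT/SASLT_2021.pdf")
pages = load_pdf_documents(pdf)  # one Document per non-empty page
print(pages[0].metadata)
# e.g. {'source': 'SASLT_2021.pdf', 'disease': 'HBV', 'provider': 'SASLT', 'page_number': 1}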
core/github_storage.py
ADDED
@@ -0,0 +1,462 @@
"""
GitHub Storage Utility for Medical RAG Advisor
Handles saving side effects reports and validation results to GitHub repository
"""
import os
import json
import csv
import io
import base64
import time
import traceback
from datetime import datetime
from typing import Dict, List, Any, Optional
import requests
from .config import logger


class GitHubStorage:
    """
    Utility class for storing medical data files in GitHub repository
    """

    def __init__(self, repo_url: str = "https://github.com/MoazEldsouky/HBV-AI-Assistant-data",
                 github_token: str = None):
        """
        Initialize GitHub storage with repository details

        Args:
            repo_url: GitHub repository URL (default: HBV AI Assistant data repository)
            github_token: GitHub personal access token
        """
        self.repo_url = repo_url
        self.github_token = github_token or os.getenv("GITHUB_TOKEN", "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")

        # Log token status (masked for security)
        if self.github_token:
            token_preview = self.github_token[:7] + "..." + self.github_token[-4:] if len(self.github_token) > 11 else "***"
            logger.info(f"GitHub token configured: {token_preview}")
        else:
            logger.warning("No GitHub token configured - uploads will fail!")

        # Extract owner and repo name from URL
        if "github.com/" in repo_url:
            parts = repo_url.replace("https://github.com/", "").replace(".git", "").split("/")
            self.owner = parts[0]
            self.repo_name = parts[1]
        else:
            raise ValueError("Invalid GitHub repository URL format")

        self.api_base = f"https://api.github.com/repos/{self.owner}/{self.repo_name}"
        self.headers = {
            "Authorization": f"token {self.github_token}",
            "Accept": "application/vnd.github.v3+json",
            "Content-Type": "application/json"
        }

        logger.info(f"GitHub storage initialized for {self.owner}/{self.repo_name}")

    def _get_file_sha(self, file_path: str) -> Optional[str]:
        """
        Get the SHA of an existing file in the repository

        Args:
            file_path: Path to file in repository

        Returns:
            SHA string if file exists, None otherwise
        """
        try:
            url = f"{self.api_base}/contents/{file_path}"
            response = requests.get(url, headers=self.headers)

            if response.status_code == 200:
                return response.json().get("sha")
            elif response.status_code == 404:
                return None
            else:
                logger.error(f"Error getting file SHA: {response.status_code} - {response.text}")
                return None

        except Exception as e:
            logger.error(f"Exception getting file SHA: {e}")
            return None

    def _upload_file(self, file_path: str, content: str, message: str, sha: Optional[str] = None) -> bool:
        """
        Upload or update a file in the GitHub repository

        Args:
            file_path: Path where file should be stored in repo
            content: File content as string
            message: Commit message
            sha: SHA of existing file (for updates)

        Returns:
            True if successful, False otherwise
        """
        try:
            # Encode content to base64
            content_encoded = base64.b64encode(content.encode('utf-8')).decode('utf-8')

            # Prepare request data
            data = {
                "message": message,
                "content": content_encoded
            }

            # Add SHA if updating existing file
            if sha:
                data["sha"] = sha

            # Make API request with timeout
            url = f"{self.api_base}/contents/{file_path}"
            logger.info(f"Uploading to GitHub: {file_path} (size: {len(content)} bytes)")
            response = requests.put(url, headers=self.headers, json=data, timeout=30)

            if response.status_code in [200, 201]:
                logger.info(f"✓ Successfully uploaded {file_path} to GitHub")
                return True
            elif response.status_code == 401:
                logger.error(f"❌ Authentication failed uploading {file_path}: Invalid or expired GitHub token")
                logger.error(f"Response: {response.text}")
                return False
            elif response.status_code == 403:
                logger.error(f"❌ Permission denied uploading {file_path}: Token lacks required permissions")
                logger.error(f"Response: {response.text}")
                return False
            elif response.status_code == 404:
                logger.error(f"❌ Repository not found: {self.owner}/{self.repo_name}")
                logger.error(f"Response: {response.text}")
                return False
            elif response.status_code == 409:
                logger.error(f"Conflict error uploading {file_path}: File may have been modified. Status: {response.status_code}")
                logger.error(f"Response: {response.text[:500]}")
                return False
            else:
                logger.error(f"Failed to upload {file_path}. Status: {response.status_code}")
                logger.error(f"Response: {response.text}")
                return False

        except requests.exceptions.Timeout as e:
            logger.error(f"Timeout uploading file to GitHub: {e}")
            return False
        except requests.exceptions.RequestException as e:
            logger.error(f"Request exception uploading file to GitHub: {e}")
            return False
        except Exception as e:
            logger.error(f"Unexpected exception uploading file to GitHub: {e}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            return False

    def _get_file_content(self, file_path: str) -> Optional[str]:
        """
        Get the content of a file from the GitHub repository

        Args:
            file_path: Path to file in repository

        Returns:
            File content as string if successful, None otherwise
        """
        try:
            url = f"{self.api_base}/contents/{file_path}"
            response = requests.get(url, headers=self.headers)

            if response.status_code == 200:
                content_encoded = response.json().get("content", "")
                content = base64.b64decode(content_encoded).decode('utf-8')
                return content
            elif response.status_code == 404:
                return None
            else:
                logger.error(f"Error getting file content: {response.status_code} - {response.text}")
                return None

        except Exception as e:
            logger.error(f"Exception getting file content: {e}")
            return None

    def save_side_effects_report(self, report_data: Dict[str, Any]) -> bool:
        """
        Save a side effects report to GitHub repository as CSV

        Args:
            report_data: Dictionary containing side effects report data

        Returns:
            True if successful, False otherwise
        """
        try:
            file_path = "medical_data/side_effects_reports.csv"

            # Get existing file content
            existing_content = self._get_file_content(file_path)

            # Define CSV fieldnames
            fieldnames = [
                'timestamp', 'drug_name', 'side_effects', 'patient_age',
                'patient_gender', 'dosage', 'duration', 'severity',
                'outcome', 'additional_details', 'reporter_info', 'raw_input'
            ]

            # Create CSV content
            output = io.StringIO()
            writer = csv.DictWriter(output, fieldnames=fieldnames)

            # If file doesn't exist, write header
            if existing_content is None:
                writer.writeheader()
                csv_content = output.getvalue()
            else:
                # File exists, append to existing content
                csv_content = existing_content

            # Append new row
            output = io.StringIO()
            writer = csv.DictWriter(output, fieldnames=fieldnames)
            writer.writerow(report_data)
            new_row = output.getvalue()

            # Combine existing content with new row
            final_content = csv_content + new_row

            # Get SHA for update
            sha = self._get_file_sha(file_path)

            # Upload file
            commit_message = f"Add side effects report for {report_data.get('drug_name', 'unknown drug')} - {report_data.get('timestamp', 'unknown time')}"

            return self._upload_file(file_path, final_content, commit_message, sha)

        except Exception as e:
            logger.error(f"Error saving side effects report to GitHub: {e}")
            return False

    def save_validation_results(self, evaluation_data: Dict[str, Any]) -> bool:
        """
        Save validation results to GitHub repository as JSON with robust append logic.
        Always loads existing data first, then appends new evaluation without overwriting.

        Args:
            evaluation_data: Dictionary containing evaluation data with interaction_id already set

        Returns:
            True if successful, False otherwise
        """
        max_retries = 3
        retry_count = 0

        while retry_count < max_retries:
            try:
                file_path = "medical_data/evaluation_results.json"

                # STEP 1: Get existing file content with verification
                logger.info(f"Attempt {retry_count + 1}/{max_retries}: Loading existing evaluations from GitHub...")
                existing_content = self._get_file_content(file_path)

                # STEP 2: Parse existing data or create new list
                evaluations = []
                if existing_content:
                    try:
                        evaluations = json.loads(existing_content)
                        if not isinstance(evaluations, list):
                            logger.warning("Existing content is not a list, creating new list")
                            evaluations = []
                        else:
                            logger.info(f"Successfully loaded {len(evaluations)} existing evaluations")
                    except json.JSONDecodeError as e:
                        logger.error(f"Failed to parse existing evaluation_results.json: {e}")
                        # Don't start fresh - this could lose data. Instead, fail and retry.
                        if retry_count < max_retries - 1:
                            retry_count += 1
                            logger.warning("Retrying due to JSON parse error...")
                            time.sleep(2)  # Wait before retry
                            continue
                        else:
                            logger.error("Max retries reached. Cannot parse existing data.")
                            return False
                else:
                    logger.info("No existing file found, creating new evaluation list")

                # STEP 3: Verify we're not about to lose data
                new_interaction_id = evaluation_data.get('interaction_id', 'unknown')
                logger.info(f"Adding new evaluation with ID: {new_interaction_id}")

                # Check if this ID already exists (prevent duplicates)
                existing_ids = [e.get('interaction_id') for e in evaluations]
                if new_interaction_id in existing_ids:
                    logger.warning(f"Evaluation with ID {new_interaction_id} already exists. Skipping duplicate.")
                    return True  # Not an error, just already saved

                # STEP 4: Add new evaluation to the list (APPEND, not replace)
                evaluations.append(evaluation_data)
                logger.info(f"Appended new evaluation. Total count: {len(evaluations)}")

                # STEP 5: Convert to JSON string
                json_content = json.dumps(evaluations, indent=2, ensure_ascii=False)

                # STEP 6: Get SHA for update (must be fresh to avoid conflicts)
                sha = self._get_file_sha(file_path)
                if existing_content and not sha:
                    logger.error("File exists but SHA not found. Possible race condition.")
                    if retry_count < max_retries - 1:
                        retry_count += 1
                        logger.warning("Retrying due to SHA retrieval failure...")
                        time.sleep(2)  # Wait before retry
                        continue
                    else:
                        return False

                # STEP 7: Upload file with the complete list
                commit_message = f"Add validation results for interaction {new_interaction_id} - {evaluation_data.get('timestamp', 'unknown time')}"

                success = self._upload_file(file_path, json_content, commit_message, sha)

                if success:
                    logger.info(f"✓ Successfully saved evaluation {new_interaction_id}. Total evaluations now: {len(evaluations)}")
                    return True
                else:
                    logger.error(f"Failed to upload file (attempt {retry_count + 1}/{max_retries})")
                    if retry_count < max_retries - 1:
                        retry_count += 1
                        logger.warning("Retrying upload...")
                        time.sleep(2)  # Wait before retry
                        continue
                    else:
                        return False

            except Exception as e:
                logger.error(f"Error saving validation results to GitHub (attempt {retry_count + 1}/{max_retries}): {e}")
                if retry_count < max_retries - 1:
                    retry_count += 1
                    logger.warning("Retrying due to exception...")
                    time.sleep(2)  # Wait before retry
                    continue
                else:
                    return False

        return False

    def get_side_effects_reports(self) -> List[Dict[str, Any]]:
        """
        Get all side effects reports from GitHub repository

        Returns:
            List of side effects reports as dictionaries
        """
        try:
            file_path = "medical_data/side_effects_reports.csv"
            content = self._get_file_content(file_path)

            if not content:
                return []

            # Parse CSV content
            csv_reader = csv.DictReader(io.StringIO(content))
            reports = list(csv_reader)

            return reports

        except Exception as e:
            logger.error(f"Error getting side effects reports from GitHub: {e}")
            return []

    def get_validation_results(self, limit: int = 10) -> Dict[str, Any]:
        """
        Get validation results from GitHub repository

        Args:
            limit: Maximum number of recent evaluations to return

        Returns:
            Dictionary containing evaluation summary and recent evaluations
        """
        try:
            file_path = "medical_data/evaluation_results.json"
            content = self._get_file_content(file_path)

            if not content:
                return {"message": "No evaluations found", "evaluations": []}

            # Parse JSON content
            evaluations = json.loads(content)
            if not isinstance(evaluations, list):
                evaluations = []

            # Get recent evaluations
            recent_evaluations = evaluations[-limit:] if evaluations else []

            # Calculate average scores
            if recent_evaluations:
                total_scores = {
                    "accuracy": 0,
                    "coherence": 0,
                    "relevance": 0,
                    "completeness": 0,
                    "citations": 0,
                    "length": 0,
                    "overall": 0
                }

                count = len(recent_evaluations)
                for eval_data in recent_evaluations:
                    report = eval_data.get("validation_report", {})
                    total_scores["accuracy"] += int(report.get("Accuracy_Rating", 0))
                    total_scores["coherence"] += int(report.get("Coherence_Rating", 0))
                    total_scores["relevance"] += int(report.get("Relevance_Rating", 0))
                    total_scores["completeness"] += int(report.get("Completeness_Rating", 0))
                    total_scores["citations"] += int(report.get("Citations_Attribution_Rating", 0))
                    total_scores["length"] += int(report.get("Length_Rating", 0))
                    total_scores["overall"] += int(report.get("Overall_Rating", 0))

                averages = {key: round(value / count, 1) for key, value in total_scores.items()}
            else:
                averages = {}

            return {
                "total_evaluations": len(evaluations),
                "recent_count": len(recent_evaluations),
                "average_scores": averages,
                "evaluations": recent_evaluations
            }

        except Exception as e:
            logger.error(f"Error getting validation results from GitHub: {e}")
            return {"error": str(e), "evaluations": []}

    def get_drug_reports(self, drug_name: str) -> List[Dict[str, Any]]:
        """
        Get side effects reports for a specific drug from GitHub repository

        Args:
            drug_name: Name of the drug to filter reports

        Returns:
            List of reports for the specified drug
        """
        try:
            all_reports = self.get_side_effects_reports()

            # Filter reports for the specific drug (case-insensitive)
            drug_reports = [
                report for report in all_reports
                if report.get('drug_name', '').lower() == drug_name.lower()
            ]

            return drug_reports

        except Exception as e:
            logger.error(f"Error getting drug reports from GitHub: {e}")
            return []


# Global GitHub storage instance
_github_storage = None

def get_github_storage() -> GitHubStorage:
    """Get the global GitHub storage instance with lazy loading."""
    global _github_storage
    if _github_storage is None:
        _github_storage = GitHubStorage()
    return _github_storage
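A sketch of one round trip through the storage helper. The report keys mirror the CSV fieldnames defined in save_side_effects_report; the field values are invented, and a valid GITHUB_TOKEN with write access to the data repository is assumed.

# Hypothetical usage sketch (not part of this commit):
from datetime import datetime, timezone
from core.github_storage import get_github_storage

storage = get_github_storage()  # reads GITHUB_TOKEN from the environment
ok = storage.save_side_effects_report({
    "timestamp": datetime.now(timezone.utc).isoformat(),
    "drug_name": "tenofovir",
    "side_effects": "nausea",
    "patient_age": "47", "patient_gender": "M",
    "dosage": "300 mg", "duration": "6 months",
    "severity": "mild", "outcome": "resolved",
    "additional_details": "", "reporter_info": "", "raw_input": "",
})
print("saved" if ok else "upload failed")
print(len(storage.get_drug_reports("tenofovir")), "reports for tenofovir")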
core/hbv_assessment.py
ADDED
@@ -0,0 +1,298 @@
"""
HBV Patient Assessment Module
Evaluates patient eligibility for HBV treatment according to SASLT 2021 guidelines
"""
import logging
import json
from typing import Dict, Any
from .retrievers import hybrid_search
from .config import get_llm

logger = logging.getLogger(__name__)


def create_patient_query(patient_data: Dict[str, Any]) -> str:
    """
    Create a comprehensive search query based on patient parameters

    Args:
        patient_data: Dictionary containing patient clinical parameters

    Returns:
        Optimized search query string for guideline retrieval
    """
    query_parts = []

    # Add HBeAg status to query
    if patient_data.get("hbeag_status") == "Positive":
        query_parts.append("HBeAg-positive chronic hepatitis B treatment eligibility")
    else:
        query_parts.append("HBeAg-negative chronic hepatitis B treatment eligibility")

    # Add viral load context
    hbv_dna = patient_data.get("hbv_dna_level", 0)
    if hbv_dna > 20000:
        query_parts.append("high HBV DNA level")
    elif hbv_dna > 2000:
        query_parts.append("moderate HBV DNA level")

    # Add ALT context
    sex = patient_data.get("sex", "Male")
    alt_level = patient_data.get("alt_level", 0)
    alt_uln = 35 if sex == "Male" else 25
    if alt_level > 2 * alt_uln:
        query_parts.append("significantly elevated ALT")
    elif alt_level > alt_uln:
        query_parts.append("elevated ALT")

    # Add fibrosis context
    fibrosis_stage = patient_data.get("fibrosis_stage", "")
    if fibrosis_stage == "F4":
        query_parts.append("cirrhosis treatment criteria")
    elif fibrosis_stage == "F2-F3":
        query_parts.append("significant fibrosis")

    # Add special populations
    if patient_data.get("pregnancy_status") == "Pregnant":
        query_parts.append("pregnancy antiviral prophylaxis")

    immunosuppression = patient_data.get("immunosuppression_status")
    if immunosuppression and immunosuppression != "None":
        query_parts.append("immunosuppression prophylactic therapy")

    coinfections = patient_data.get("coinfections", [])
    if coinfections:
        query_parts.append("coinfection management")

    if patient_data.get("extrahepatic_manifestations"):
        query_parts.append("extrahepatic manifestations")

    if patient_data.get("family_history_cirrhosis_hcc"):
        query_parts.append("family history HCC cirrhosis")

    # Combine into search query
    base_query = " ".join(query_parts[:3])  # Use top 3 most relevant parts
    return f"{base_query} treatment indications criteria SASLT 2021"


def assess_hbv_eligibility(patient_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Assess patient eligibility for HBV treatment based on SASLT 2021 guidelines
    using retrieval from vector store and LLM analysis

    Args:
        patient_data: Dictionary containing patient clinical parameters

    Returns:
        Dictionary with assessment results:
        - eligible: bool
        - recommendations: str (comprehensive narrative with inline citations in format [SASLT 2021, Page X])
    """
    try:
        # Check if HBsAg is positive (required for treatment consideration)
        if patient_data.get("hbsag_status") != "Positive":
            return {
                "eligible": False,
                "recommendations": "Patient is HBsAg negative. HBV treatment is not indicated. HBsAg positivity is required for HBV treatment consideration according to SASLT 2021 guidelines."
            }

        # Create search query based on patient parameters
        search_query = create_patient_query(patient_data)
        logger.info(f"Generated search query: {search_query}")

        # Retrieve relevant guidelines from vector store
        docs = hybrid_search(search_query, provider="SASLT", k_vector=8, k_bm25=2)

        if not docs:
            logger.warning("No documents retrieved from vector store")
            return {
                "eligible": False,
                "recommendations": "Unable to retrieve SASLT 2021 guidelines. Please try again."
            }

        # Log retrieved documents
        logger.info(f"\n{'='*80}")
        logger.info(f"RETRIEVED DOCUMENTS ({len(docs)} documents)")
        logger.info(f"{'='*80}")
        for i, doc in enumerate(docs, 1):
            source = doc.metadata.get('source', 'Unknown')
            page = doc.metadata.get('page_number', 'N/A')
            content_preview = doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content
            logger.info(f"\n📄 Document {i}:")
            logger.info(f"   Source: {source}")
            logger.info(f"   Page: {page}")
            logger.info(f"   Content Preview: {content_preview}")
        logger.info(f"{'='*80}\n")

        # Format retrieved documents for LLM
        context = "\n\n".join([
            f"[Source: {doc.metadata.get('source', 'Unknown')}, Page: {doc.metadata.get('page_number', 'N/A')}]\n{doc.page_content}"
            for doc in docs
        ])

        # Define ALT ULN for context
        sex = patient_data.get("sex", "Male")
        alt_uln = 35 if sex == "Male" else 25

        # Format patient data for prompt
        age = patient_data.get("age", "N/A")
        pregnancy_status = patient_data.get("pregnancy_status", "N/A")
        hbsag_status = patient_data.get("hbsag_status", "N/A")
        duration_hbsag = patient_data.get("duration_hbsag_months", "N/A")
        hbv_dna = patient_data.get("hbv_dna_level", 0)
        hbeag_status = patient_data.get("hbeag_status", "N/A")
        alt_level = patient_data.get("alt_level", 0)
        fibrosis_stage = patient_data.get("fibrosis_stage", "N/A")
        necroinflammatory = patient_data.get("necroinflammatory_activity", "N/A")
        extrahepatic = patient_data.get("extrahepatic_manifestations", False)
        immunosuppression = patient_data.get("immunosuppression_status", "None")
        coinfections = patient_data.get("coinfections", [])
        family_history = patient_data.get("family_history_cirrhosis_hcc", False)
        comorbidities = patient_data.get("other_comorbidities", [])

        # Create prompt for LLM to analyze patient against guidelines
        analysis_prompt = f"""You are an HBV treatment eligibility assessment system. Analyze the patient data against SASLT 2021 guidelines.

PATIENT DATA:
- Sex: {sex}
- Age: {age} years
- Pregnancy Status: {pregnancy_status}
- HBsAg Status: {hbsag_status}
- HBsAg Duration: {duration_hbsag} months
- HBV DNA Level: {hbv_dna:,.0f} IU/mL
- HBeAg Status: {hbeag_status}
- ALT Level: {alt_level} U/L (ULN: {alt_uln} U/L)
- Fibrosis Stage: {fibrosis_stage}
- Necroinflammatory Activity: {necroinflammatory}
- Extrahepatic Manifestations: {extrahepatic}
- Immunosuppression: {immunosuppression}
- Coinfections: {', '.join(coinfections) if coinfections else 'None'}
- Family History (Cirrhosis/HCC): {family_history}
- Other Comorbidities: {', '.join(comorbidities) if comorbidities else 'None'}

HBV ELIGIBILITY CRITERIA & TREATMENT OPTIONS – SASLT 2021:

TREATMENT ELIGIBILITY CRITERIA:
1. HBV DNA > 2,000 IU/mL, ALT > ULN, regardless of HBeAg status, and/or at least moderate liver necroinflammation or fibrosis (Grade A)
2. Patients with cirrhosis (compensated or decompensated), with any detectable HBV DNA level and regardless of ALT levels (Grade A)
3. HBV DNA > 20,000 IU/mL and ALT > 2xULN, regardless of the degree of fibrosis (Grade B)
4. HBeAg-positive chronic HBV infection (persistently normal ALT and high HBV DNA levels) may be treated if they are > 30 years, regardless of the severity of liver histological lesions (Grade D)
5. HBV DNA > 2,000 IU/mL, ALT > ULN, regardless of HBeAg status, and a family history of HCC or cirrhosis and extrahepatic manifestations (Grade D)

TREATMENT CHOICES:
- Preferred regimens are ETV (entecavir), TDF (tenofovir disoproxil fumarate), and TAF (tenofovir alafenamide) as monotherapies (Grade A)

MANAGEMENT ALGORITHM:
• HBsAg positive with chronic HBV infection and no signs of chronic hepatitis → Monitor (HBsAg, HBeAg, HBV DNA, ALT, fibrosis assessment). Consider: risk of HCC, risk of HBV reactivation, extrahepatic manifestations, risk of HBV transmission
• CHB (with/without cirrhosis) → Start antiviral treatment if indicated, otherwise return to monitoring
• HBsAg negative, anti-HBc positive → No specialist follow-up (inform about HBV reactivation risk). In case of immunosuppression: start oral antiviral prophylaxis or monitor

SASLT 2021 GUIDELINES (Retrieved Context):
{context}

Based STRICTLY on the SASLT 2021 guidelines and criteria provided above, assess this patient's eligibility for HBV antiviral treatment.

You MUST respond with a valid JSON object in this exact format:
{{
    "eligible": true or false,
    "recommendations": "Comprehensive assessment with inline citations"
}}

CRITICAL CITATION REQUIREMENTS:
1. The "recommendations" field must be a comprehensive narrative that includes:
   - Eligibility determination with rationale
   - Specific criteria met or not met from the guidelines
   - Treatment options if eligible (ETV, TDF, TAF as first-line agents)
   - Special considerations (pregnancy, immunosuppression, coinfections, etc.)
   - Any additional clinical notes
   - **References** section at the end listing all cited pages

2. EVERY statement in recommendations MUST include inline citations in this format:
   "[SASLT 2021, Page X]" where X is the specific page number

3. Example format:
   "Patient meets treatment criteria based on HBV DNA > 2,000 IU/mL, ALT > ULN, and moderate fibrosis (Grade A) [SASLT 2021, Page 12]. First-line antiviral agents including entecavir (ETV), tenofovir disoproxil fumarate (TDF), or tenofovir alafenamide (TAF) are recommended [SASLT 2021, Page 15]. Patient should be monitored for treatment response [SASLT 2021, Page 18].

   **References**
   SASLT 2021 Guidelines - Pages: 12, 15, 18
   (Treatment Eligibility Criteria, First-Line Antiviral Agents, Monitoring Protocols)"

4. ALWAYS cite the specific page number from the [Source: ..., Page: X] markers in the guidelines above

5. Include evidence grade (Grade A, B, C, D) when available in the guidelines

6. END the recommendations with a **References** section that lists all cited pages in ascending order with brief description of topics covered

IMPORTANT:
1. Base your assessment ONLY on the SASLT 2021 guidelines provided
2. Make recommendations comprehensive and detailed
3. Cite page numbers after EVERY clinical statement or recommendation
4. Use the format [SASLT 2021, Page X] for all citations
5. Include a **References** section at the end listing all pages cited
6. Return ONLY the JSON object, no additional text
"""

        # Log the complete prompt being sent to LLM
        logger.info(f"\n{'='*80}")
        logger.info(f"LLM PROMPT")
        logger.info(f"{'='*80}")
        logger.info(f"\n{analysis_prompt}\n")
        logger.info(f"{'='*80}\n")

        # Get LLM response
        llm = get_llm()
        logger.info("Sending prompt to LLM...")
        response = llm.invoke(analysis_prompt)
        logger.info("LLM response received")

        # Extract JSON from response
        response_text = response.content if hasattr(response, 'content') else str(response)

        # Log LLM response
        logger.info(f"\n{'='*80}")
        logger.info(f"LLM RESPONSE")
        logger.info(f"{'='*80}")
        logger.info(f"\n{response_text}\n")
        logger.info(f"{'='*80}\n")

        # Try to parse JSON from response
        try:
            # Find JSON in response (handle cases where LLM adds extra text)
            json_start = response_text.find('{')
            json_end = response_text.rfind('}') + 1
            if json_start >= 0 and json_end > json_start:
                json_str = response_text[json_start:json_end]
                result = json.loads(json_str)
                logger.info(f"✅ Successfully parsed JSON response")
            else:
                raise ValueError("No JSON found in response")

            # Validate and return result
            assessment_result = {
                "eligible": result.get("eligible", False),
                "recommendations": result.get("recommendations", "")
            }

            # Log final assessment
            logger.info(f"\n{'='*80}")
            logger.info(f"FINAL ASSESSMENT")
            logger.info(f"{'='*80}")
            logger.info(f"Eligible: {assessment_result['eligible']}")
            logger.info(f"Recommendations length: {len(assessment_result['recommendations'])} characters")
            logger.info(f"{'='*80}\n")

            return assessment_result

        except (json.JSONDecodeError, ValueError) as e:
            logger.error(f"Failed to parse LLM response as JSON: {e}")
            logger.error(f"Response text: {response_text}")

            # Fallback: return error response
            return {
                "eligible": False,
                "recommendations": f"Error parsing assessment results. Please try again. Error details: {str(e)}"
            }

    except Exception as e:
        logger.error(f"Error in assess_hbv_eligibility: {str(e)}")
        raise
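A minimal usage sketch for the assessment module above. It assumes the package is importable as `core` and that an OpenAI key is configured for `get_llm()`; the patient values are purely illustrative.

# Hypothetical caller; field names follow the keys this module reads via patient_data.get(...).
from core.hbv_assessment import assess_hbv_eligibility

patient = {
    "sex": "Male",
    "age": 45,
    "pregnancy_status": "Not pregnant",
    "hbsag_status": "Positive",
    "duration_hbsag_months": 12,
    "hbv_dna_level": 25000.0,   # IU/mL
    "alt_level": 80.0,          # U/L (male ULN is 35 in this module)
    "hbeag_status": "Positive",
    "fibrosis_stage": "F2-F3",
    "necroinflammatory_activity": "A2",
    "extrahepatic_manifestations": False,
    "immunosuppression_status": "None",
    "coinfections": [],
    "family_history_cirrhosis_hcc": False,
    "other_comorbidities": None,
}

result = assess_hbv_eligibility(patient)
print(result["eligible"])           # bool
print(result["recommendations"])    # narrative with [SASLT 2021, Page X] citations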
core/medical_terminology.py
ADDED
@@ -0,0 +1,288 @@
"""
Medical Terminology Module for HBV (Hepatitis B Virus)

This module provides intelligent handling of HBV medical linguistic variability including:
- Synonyms and alternate terms
- Abbreviations and acronyms (with context awareness)
- Regional spelling variations (US/UK/International)
- Specialty-specific terminology
- Dynamic learning from corpus
"""

import re
import json
from typing import List, Dict, Set, Tuple, Optional
from collections import defaultdict
from pathlib import Path
from .config import logger

# ============================================================================
# CORE HBV MEDICAL TERMINOLOGY MAPPINGS
# ============================================================================

# Common HBV medical abbreviations with context-aware expansions
MEDICAL_ABBREVIATIONS = {
    # HBV Terminology
    "hbv": ["hepatitis b virus", "hepatitis b"],
    "hbsag": ["hepatitis b surface antigen", "hbs antigen"],
    "hbeag": ["hepatitis b e antigen", "hbe antigen"],
    "hbcag": ["hepatitis b core antigen"],
    "anti-hbs": ["antibody to hepatitis b surface antigen", "anti-hbs antibody"],
    "anti-hbe": ["antibody to hepatitis b e antigen"],
    "anti-hbc": ["antibody to hepatitis b core antigen"],
    "hbv dna": ["hepatitis b virus dna", "hbv viral load"],

    # Liver Disease Terms
    "alt": ["alanine aminotransferase", "alanine transaminase", "sgpt"],
    "ast": ["aspartate aminotransferase", "aspartate transaminase", "sgot"],
    "alp": ["alkaline phosphatase"],
    "ggt": ["gamma-glutamyl transferase", "gamma glutamyl transpeptidase"],
    "inr": ["international normalized ratio"],
    "pt": ["prothrombin time"],
    "apri": ["ast to platelet ratio index"],
    "fib-4": ["fibrosis-4 index"],

    # Fibrosis Staging
    "f0": ["no fibrosis"],
    "f1": ["mild fibrosis", "portal fibrosis"],
    "f2": ["moderate fibrosis"],
    "f3": ["severe fibrosis", "advanced fibrosis"],
    "f4": ["cirrhosis", "liver cirrhosis"],

    # Necroinflammatory Activity
    "a0": ["no activity"],
    "a1": ["mild activity"],
    "a2": ["moderate activity"],
    "a3": ["severe activity"],

    # Treatment Terms
    "etv": ["entecavir"],
    "tdf": ["tenofovir disoproxil fumarate", "tenofovir df"],
    "taf": ["tenofovir alafenamide"],
    "lam": ["lamivudine", "3tc"],
    "adv": ["adefovir", "adefovir dipivoxil"],
    "ldt": ["telbivudine"],
    "peg-ifn": ["pegylated interferon", "peginterferon"],
    "ifn": ["interferon"],

    # Complications
    "hcc": ["hepatocellular carcinoma", "liver cancer"],
    "dc": ["decompensated cirrhosis"],
    "cc": ["compensated cirrhosis"],
    "esld": ["end-stage liver disease"],
    "alf": ["acute liver failure"],
    "aclf": ["acute-on-chronic liver failure"],

    # Coinfections
    "hiv": ["human immunodeficiency virus"],
    "hcv": ["hepatitis c virus", "hepatitis c"],
    "hdv": ["hepatitis d virus", "hepatitis delta"],
    "hav": ["hepatitis a virus", "hepatitis a"],

    # Clinical Terms
    "uln": ["upper limit of normal"],
    "iu/ml": ["international units per milliliter"],
    "log": ["logarithm", "log10"],
    "svr": ["sustained virological response"],
    "vr": ["virological response"],
    "br": ["biochemical response"],
    "sr": ["serological response"],
}

# Synonym mappings for HBV medical terms
MEDICAL_SYNONYMS = {
    # HBV terminology
    "hepatitis b": ["hbv", "hepatitis b virus", "hep b", "hbv infection"],
    "chronic hepatitis b": ["chb", "chronic hbv", "chronic hbv infection"],
    "acute hepatitis b": ["ahb", "acute hbv"],
    "hbv dna": ["viral load", "hbv viral load", "serum hbv dna"],

    # Serological markers
    "hbsag positive": ["hbsag+", "hbs antigen positive"],
    "hbeag positive": ["hbeag+", "hbe antigen positive"],
    "hbsag negative": ["hbsag-", "hbs antigen negative"],
    "hbeag negative": ["hbeag-", "hbe antigen negative"],

    # Liver disease stages
    "cirrhosis": ["f4", "liver cirrhosis", "hepatic cirrhosis"],
    "fibrosis": ["liver fibrosis", "hepatic fibrosis"],
    "compensated cirrhosis": ["cc", "child-pugh a", "child-pugh b"],
    "decompensated cirrhosis": ["dc", "child-pugh c"],

    # Treatment terms
    "antiviral therapy": ["antiviral treatment", "nucleos(t)ide analogue", "na therapy"],
    "entecavir": ["etv", "baraclude"],
    "tenofovir": ["tdf", "taf", "viread", "vemlidy"],
    "interferon": ["ifn", "pegylated interferon", "peg-ifn"],

    # Clinical outcomes
    "treatment response": ["virological response", "biochemical response"],
    "viral suppression": ["undetectable hbv dna", "hbv dna < lloq"],
    "alt normalization": ["alt normal", "alt within normal limits"],

    # Complications
    "hepatocellular carcinoma": ["hcc", "liver cancer", "primary liver cancer"],
    "liver failure": ["hepatic failure", "end-stage liver disease", "esld"],
    "portal hypertension": ["esophageal varices", "ascites", "splenomegaly"],

    # Special populations
    "pregnant women": ["pregnancy", "pregnant patients"],
    "immunosuppressed": ["immunocompromised", "on immunosuppression"],
    "coinfection": ["co-infection", "dual infection"],
}

# Regional spelling variations (US/UK/International)
SPELLING_VARIATIONS = {
    "fibrosis": ["fibrosis"],
    "cirrhosis": ["cirrhosis"],
    "anaemia": ["anemia"],
    "haemorrhage": ["hemorrhage"],
    "oesophageal": ["esophageal"],
}

# Context-specific term preferences
CONTEXT_PREFERENCES = {
    "treatment": ["antiviral", "therapy", "regimen", "medication"],
    "diagnosis": ["hbsag", "hbeag", "hbv dna", "serology"],
    "monitoring": ["alt", "hbv dna", "liver function", "fibrosis"],
    "complications": ["hcc", "cirrhosis", "decompensation", "liver failure"],
}

# ============================================================================
# DYNAMIC TERMINOLOGY LEARNING
# ============================================================================

class MedicalTerminologyExpander:
    """
    Dynamically learns and expands medical terminology from corpus.
    Handles abbreviations, synonyms, and context-specific variations for HBV.
    """

    def __init__(self, corpus_path: Optional[Path] = None):
        """Initialize with optional corpus for dynamic learning."""
        self.abbreviations = MEDICAL_ABBREVIATIONS.copy()
        self.synonyms = MEDICAL_SYNONYMS.copy()
        self.spelling_vars = SPELLING_VARIATIONS.copy()
        self.learned_terms = defaultdict(set)

        if corpus_path and corpus_path.exists():
            self._learn_from_corpus(corpus_path)

    def expand_query(self, query: str, context: Optional[str] = None) -> List[str]:
        """
        Expand a query with medical synonyms and abbreviations.

        Args:
            query: Original query string
            context: Optional context hint (e.g., 'treatment', 'diagnosis')

        Returns:
            List of expanded query variations
        """
        expansions = [query]
        query_lower = query.lower()

        # Expand abbreviations
        for abbrev, full_forms in self.abbreviations.items():
            if abbrev in query_lower:
                for full_form in full_forms:
                    expansions.append(query_lower.replace(abbrev, full_form))

        # Expand synonyms
        for term, synonyms in self.synonyms.items():
            if term in query_lower:
                for synonym in synonyms:
                    expansions.append(query_lower.replace(term, synonym))

        # Add context-specific preferences
        if context and context in CONTEXT_PREFERENCES:
            for pref_term in CONTEXT_PREFERENCES[context]:
                if pref_term not in query_lower:
                    expansions.append(f"{query} {pref_term}")

        # Remove duplicates while preserving order
        seen = set()
        unique_expansions = []
        for exp in expansions:
            if exp not in seen:
                seen.add(exp)
                unique_expansions.append(exp)

        return unique_expansions

    def normalize_term(self, term: str) -> str:
        """
        Normalize a medical term to its canonical form.

        Args:
            term: Medical term to normalize

        Returns:
            Normalized canonical form
        """
        term_lower = term.lower().strip()

        # Check if it's an abbreviation
        if term_lower in self.abbreviations:
            return self.abbreviations[term_lower][0]

        # Check if it's a synonym
        for canonical, synonyms in self.synonyms.items():
            if term_lower in synonyms or term_lower == canonical:
                return canonical

        # Check spelling variations
        for canonical, variations in self.spelling_vars.items():
            if term_lower in variations:
                return canonical

        return term

    def _learn_from_corpus(self, corpus_path: Path):
        """Learn new terminology patterns from corpus."""
        try:
            # Implementation for dynamic learning from HBV guidelines
            logger.info(f"Learning terminology from corpus: {corpus_path}")
            # This would analyze the corpus and extract new term relationships
        except Exception as e:
            logger.warning(f"Could not learn from corpus: {e}")

    def get_related_terms(self, term: str, max_terms: int = 5) -> List[str]:
        """
        Get related medical terms for a given term.

        Args:
            term: Medical term
            max_terms: Maximum number of related terms to return

        Returns:
            List of related terms
        """
        related = set()
        term_lower = term.lower()

        # Find synonyms
        for canonical, synonyms in self.synonyms.items():
            if term_lower == canonical or term_lower in synonyms:
                related.update(synonyms)
                related.add(canonical)

        # Find abbreviations
        if term_lower in self.abbreviations:
            related.update(self.abbreviations[term_lower])

        # Remove the original term
        related.discard(term_lower)

        return list(related)[:max_terms]


# Global instance for easy access
_global_expander = None

def get_terminology_expander() -> MedicalTerminologyExpander:
    """Get or create the global terminology expander instance."""
    global _global_expander
    if _global_expander is None:
        _global_expander = MedicalTerminologyExpander()
    return _global_expander
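An illustrative use of the expander API above; the outputs are determined entirely by the static mapping tables, so the exact lists may vary if those tables change.

from core.medical_terminology import get_terminology_expander

expander = get_terminology_expander()

# Abbreviation and synonym expansion; the original query is always kept first.
print(expander.expand_query("hbv dna thresholds for treatment", context="treatment"))

# Canonical forms come from the first entry of each abbreviation list.
print(expander.normalize_term("etv"))            # -> "entecavir"
print(expander.get_related_terms("tenofovir"))   # -> e.g. ["tdf", "taf", "viread", "vemlidy"]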
core/retrievers.py
ADDED
@@ -0,0 +1,200 @@
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional

from . import utils
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.schema import Document
from .config import logger
from .tracing import traceable

# Global configuration for retrieval parameters
# Callers typically pass larger per-query values (e.g., k_vector=8 in the assessment flow)
DEFAULT_K_VECTOR = 1  # Number of documents to retrieve from vector search
DEFAULT_K_BM25 = 1  # Number of documents to retrieve from BM25 search

# Global variables for lazy loading
_vector_store = None
_chunks = None
_vector_retriever = None
_bm25_retriever = None
_hybrid_retriever = None
_initialized = False

def _ensure_initialized():
    """Initialize retrievers on first use (lazy loading for faster startup)"""
    global _vector_store, _chunks, _vector_retriever, _bm25_retriever, _hybrid_retriever, _initialized

    if _initialized:
        return

    logger.info("🔄 Initializing retrievers (first time use)...")

    # Process any new data and update vector store and chunks cache
    try:
        logger.info("🔄 Processing new data and updating vector store if needed...")
        _vector_store = utils.process_new_data_and_update_vector_store()
        if _vector_store is None:
            # Fall back to loading the existing store if processing found no new files
            _vector_store = utils.load_vector_store()
            if _vector_store is None:
                # As a last resort, create from whatever is already in cache (if any)
                logger.info("ℹ️ No vector store found; attempting creation from cached chunks...")
                cached_chunks = utils.load_chunks() or []
                if cached_chunks:
                    _vector_store = utils.create_vector_store(cached_chunks)
                    logger.info("✅ Vector store created from cached chunks")
                else:
                    logger.warning("⚠️ No data available to build a vector store. Retrievers may not function until data is provided.")
    except Exception as e:
        logger.error(f"Error preparing vector store: {str(e)}")
        raise

    # Load merged chunks for BM25 (includes previous + new)
    try:
        logger.info("📦 Loading chunks cache for BM25 retriever...")
        _chunks = utils.load_chunks() or []
        if not _chunks:
            logger.warning("⚠️ No chunks available for BM25 retriever. BM25 will be empty until data is processed.")
    except Exception as e:
        logger.error(f"Error loading chunks: {str(e)}")
        raise

    # Create vector retriever
    logger.info("🔍 Creating vector retriever...")
    _vector_retriever = _vector_store.as_retriever(search_kwargs={"k": 5}) if _vector_store else None

    # Create BM25 retriever
    logger.info("📝 Creating BM25 retriever...")
    _bm25_retriever = BM25Retriever.from_documents(_chunks) if _chunks else None
    if _bm25_retriever:
        _bm25_retriever.k = 5

    # Create hybrid retriever
    logger.info("🔄 Creating hybrid retriever...")
    if _vector_retriever and _bm25_retriever:
        _hybrid_retriever = EnsembleRetriever(
            retrievers=[_bm25_retriever, _vector_retriever],
            weights=[0.2, 0.8]
        )
    elif _vector_retriever:
        logger.warning("ℹ️ BM25 retriever unavailable; using vector retriever only.")
        _hybrid_retriever = _vector_retriever
    elif _bm25_retriever:
        _hybrid_retriever = _bm25_retriever
    else:
        raise RuntimeError("Neither vector nor BM25 retrievers could be initialized. Provide data under data/new_data and retry.")

    _initialized = True
    logger.info("✅ Retrievers initialized successfully.")


def initialize_eagerly():
    """Force initialization of retrievers for background loading"""
    _ensure_initialized()


def is_initialized() -> bool:
    """Check if retrievers are already initialized"""
    return _initialized


# -----------------------------------------------
# Provider-aware retrieval helper functions
# -----------------------------------------------
_retrieval_pool = ThreadPoolExecutor(max_workers=4)


def _get_doc_id(doc: Document) -> str:
    """Generate unique identifier for a document."""
    source = doc.metadata.get('source', 'unknown')
    page = doc.metadata.get('page_number', 'unknown')
    content_hash = hash(doc.page_content[:200])  # Hash first 200 chars
    return f"{source}_{page}_{content_hash}"


def _match_provider(doc, provider: str) -> bool:
    if not provider:
        return True
    prov = str(doc.metadata.get("provider", "")).strip().lower()
    return prov == provider.strip().lower()


@traceable(name="VectorRetriever")
def vector_search(query: str, provider: str | None = None, k: Optional[int] = None):
    """Search FAISS vector store with optional provider metadata filter."""
    _ensure_initialized()
    if not _vector_store:
        return []

    # Use global default if k is not specified
    if k is None:
        k = DEFAULT_K_VECTOR

    try:
        # Standard search
        if provider:
            docs = _vector_store.similarity_search(query, k=k, filter={"provider": provider})
        else:
            docs = _vector_store.similarity_search(query, k=k)

        # Apply a provider post-filter in case the backend filter is lenient
        if provider:
            docs = [d for d in docs if _match_provider(d, provider)]
        return docs
    except Exception as e:
        logger.error(f"Vector search failed: {e}")
        return []


@traceable(name="BM25Retriever")
def bm25_search(query: str, provider: str | None = None, k: Optional[int] = None):
    """Search BM25 using the global retriever with optional provider filter."""
    _ensure_initialized()

    # Use global default if k is not specified
    if k is None:
        k = DEFAULT_K_BM25

    try:
        if not _bm25_retriever:
            return []

        # Standard search
        _bm25_retriever.k = max(1, k)
        docs = _bm25_retriever.invoke(query) or []

        if provider:
            docs = [d for d in docs if _match_provider(d, provider)]
        return docs[:k]
    except Exception as e:
        logger.error(f"BM25 search failed: {e}")
        return []


def hybrid_search(query: str, provider: str | None = None, k_vector: Optional[int] = None, k_bm25: Optional[int] = None):
    """Combine vector and BM25 results (provider-filtered if provided)."""
    _ensure_initialized()  # Ensure retrievers are initialized before parallel execution

    # Use global defaults if not specified
    if k_vector is None:
        k_vector = DEFAULT_K_VECTOR
    if k_bm25 is None:
        k_bm25 = DEFAULT_K_BM25

    f_vector = _retrieval_pool.submit(vector_search, query, provider, k_vector)
    f_bm25 = _retrieval_pool.submit(bm25_search, query, provider, k_bm25)

    v_docs = f_vector.result()
    b_docs = f_bm25.result()
    # Merge uniquely by document ID
    seen = set()
    merged = []
    for d in v_docs + b_docs:
        doc_id = _get_doc_id(d)
        if doc_id not in seen:
            seen.add(doc_id)
            merged.append(d)

    logger.info(f"Hybrid search returned {len(merged)} unique documents")
    return merged
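A minimal sketch of calling the hybrid retriever directly. The first call runs the lazy initialization above, so it assumes a vector store (or cached chunks) already exists under data/.

from core.retrievers import hybrid_search, is_initialized

docs = hybrid_search(
    "HBV treatment eligibility criteria",
    provider="SASLT",  # metadata filter applied in both search paths
    k_vector=5,
    k_bm25=2,
)
for doc in docs:
    print(doc.metadata.get("page_number"), doc.page_content[:80])

print(is_initialized())  # True after the first search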
core/text_parser.py
ADDED
@@ -0,0 +1,194 @@
"""
Text Parser Module
Parses free-form text input to extract structured patient data
"""
import re
import logging
from typing import Dict, Any, Optional
from .config import get_llm

logger = logging.getLogger(__name__)


def parse_patient_text(text_input: str) -> Dict[str, Any]:
    """
    Parse free-form text input to extract structured patient data
    using LLM-based extraction

    Args:
        text_input: Free-form text containing patient data

    Returns:
        Dictionary containing structured patient data matching HBVPatientInput schema
    """
    try:
        # Create prompt for LLM to extract structured data
        extraction_prompt = f"""You are a medical data extraction system. Extract structured patient data from the following free-form text.

PATIENT TEXT:
{text_input}

Extract the following information and return it as a JSON object. If a field is not mentioned, use reasonable defaults:

Required fields:
- sex: "Male" or "Female"
- age: integer (years)
- pregnancy_status: "Not pregnant" or "Pregnant"
- hbsag_status: "Positive" or "Negative"
- duration_hbsag_months: integer (months HBsAg has been positive)
- hbv_dna_level: float (IU/mL)
- hbeag_status: "Positive" or "Negative"
- alt_level: float (U/L)
- fibrosis_stage: "F0-F1", "F2-F3", or "F4"
- necroinflammatory_activity: "A0", "A1", "A2", or "A3"
- extrahepatic_manifestations: true or false
- immunosuppression_status: "None", "Chemotherapy", or "Other"
- coinfections: array of strings (e.g., ["HIV"], ["HCV"], ["HDV"], or [])
- family_history_cirrhosis_hcc: true or false
- other_comorbidities: array of strings or null

IMPORTANT EXTRACTION RULES:
1. For sex: Look for "male", "female", "man", "woman", etc.
2. For age: Extract the number followed by "year" or "years old"
3. For pregnancy_status: Default to "Not pregnant" unless explicitly mentioned
4. For HBsAg status: Look for "HBsAg positive" or "HBsAg negative"
5. For duration_hbsag_months: Look for duration in months or years (convert years to months)
6. For HBV DNA: Look for numbers followed by "IU/mL" or "IU/ml"
7. For HBeAg: Look for "HBeAg positive" or "HBeAg negative"
8. For ALT: Look for "ALT" followed by number and "U/L"
9. For fibrosis: Look for "F0", "F1", "F2", "F3", "F4" or descriptions like "significant fibrosis", "cirrhosis"
10. For necroinflammatory: Look for "A0", "A1", "A2", "A3"
11. For extrahepatic manifestations: Look for mentions of extrahepatic conditions
12. For immunosuppression: Look for "immunosuppression", "chemotherapy", etc.
13. For coinfections: Look for "HIV", "HCV", "HDV"
14. For family history: Look for "family history" of "cirrhosis" or "HCC"

DEFAULT VALUES (use if not mentioned):
- pregnancy_status: "Not pregnant"
- immunosuppression_status: "None"
- coinfections: []
- extrahepatic_manifestations: false
- family_history_cirrhosis_hcc: false
- other_comorbidities: null

Return ONLY a valid JSON object with the extracted data. Do not include any explanatory text.

Example format:
{{
    "sex": "Male",
    "age": 45,
    "pregnancy_status": "Not pregnant",
    "hbsag_status": "Positive",
    "duration_hbsag_months": 12,
    "hbv_dna_level": 5000.0,
    "hbeag_status": "Positive",
    "alt_level": 80.0,
    "fibrosis_stage": "F2-F3",
    "necroinflammatory_activity": "A2",
    "extrahepatic_manifestations": false,
    "immunosuppression_status": "None",
    "coinfections": [],
    "family_history_cirrhosis_hcc": false,
    "other_comorbidities": null
}}
"""

        # Get LLM response
        llm = get_llm()
        logger.info("Sending text extraction prompt to LLM...")
        response = llm.invoke(extraction_prompt)
        logger.info("LLM response received for text extraction")

        # Extract JSON from response
        response_text = response.content if hasattr(response, 'content') else str(response)

        # Log the response
        logger.info(f"Text extraction response: {response_text}")

        # Try to parse JSON from response
        import json
        try:
            # Find JSON in response
            json_start = response_text.find('{')
            json_end = response_text.rfind('}') + 1
            if json_start >= 0 and json_end > json_start:
                json_str = response_text[json_start:json_end]
                patient_data = json.loads(json_str)
                logger.info(f"✅ Successfully extracted patient data from text")
                return patient_data
            else:
                raise ValueError("No JSON found in response")

        except (json.JSONDecodeError, ValueError) as e:
            logger.error(f"Failed to parse LLM response as JSON: {e}")
            logger.error(f"Response text: {response_text}")
            raise ValueError(f"Failed to extract structured data from text: {str(e)}")

    except Exception as e:
        logger.error(f"Error in parse_patient_text: {str(e)}")
        raise


def validate_extracted_data(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Validate and clean extracted patient data

    Args:
        data: Extracted patient data dictionary

    Returns:
        Validated and cleaned patient data
    """
    # Ensure required fields are present
    required_fields = [
        'sex', 'age', 'pregnancy_status', 'hbsag_status',
        'duration_hbsag_months', 'hbv_dna_level', 'hbeag_status',
        'alt_level', 'fibrosis_stage', 'necroinflammatory_activity',
        'extrahepatic_manifestations', 'family_history_cirrhosis_hcc'
    ]

    for field in required_fields:
        if field not in data:
            raise ValueError(f"Missing required field: {field}")

    # Set defaults for optional fields
    if 'immunosuppression_status' not in data or not data['immunosuppression_status']:
        data['immunosuppression_status'] = 'None'

    if 'coinfections' not in data:
        data['coinfections'] = []

    if 'other_comorbidities' not in data:
        data['other_comorbidities'] = None

    # Validate data types and values
    try:
        data['age'] = int(data['age'])
        data['duration_hbsag_months'] = int(data['duration_hbsag_months'])
        data['hbv_dna_level'] = float(data['hbv_dna_level'])
        data['alt_level'] = float(data['alt_level'])
        data['extrahepatic_manifestations'] = bool(data['extrahepatic_manifestations'])
        data['family_history_cirrhosis_hcc'] = bool(data['family_history_cirrhosis_hcc'])
    except (ValueError, TypeError) as e:
        raise ValueError(f"Invalid data type in extracted data: {str(e)}")

    # Validate enum values
    if data['sex'] not in ['Male', 'Female']:
        raise ValueError(f"Invalid sex value: {data['sex']}")

    if data['pregnancy_status'] not in ['Not pregnant', 'Pregnant']:
        raise ValueError(f"Invalid pregnancy_status value: {data['pregnancy_status']}")

    if data['hbsag_status'] not in ['Positive', 'Negative']:
        raise ValueError(f"Invalid hbsag_status value: {data['hbsag_status']}")

    if data['hbeag_status'] not in ['Positive', 'Negative']:
        raise ValueError(f"Invalid hbeag_status value: {data['hbeag_status']}")

    if data['fibrosis_stage'] not in ['F0-F1', 'F2-F3', 'F4']:
        raise ValueError(f"Invalid fibrosis_stage value: {data['fibrosis_stage']}")

    if data['necroinflammatory_activity'] not in ['A0', 'A1', 'A2', 'A3']:
        raise ValueError(f"Invalid necroinflammatory_activity value: {data['necroinflammatory_activity']}")

    return data
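A sketch of the intended two-step flow, LLM extraction followed by schema validation; the clinical note is invented for illustration and the call requires a configured LLM via `get_llm()`.

from core.text_parser import parse_patient_text, validate_extracted_data

note = (
    "45 year old male, HBsAg positive for 12 months, HBeAg positive, "
    "HBV DNA 25,000 IU/mL, ALT 80 U/L, fibrosis stage F2-F3, activity A2."
)

raw = parse_patient_text(note)          # dict extracted by the LLM
patient = validate_extracted_data(raw)  # raises ValueError on missing/invalid fields
print(patient["hbv_dna_level"], patient["fibrosis_stage"])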
core/text_processors.py
ADDED
@@ -0,0 +1,20 @@
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    MarkdownHeaderTextSplitter
)

recursive_splitter = RecursiveCharacterTextSplitter(
    chunk_size=3500,
    chunk_overlap=400,
    length_function=len,
    separators=["\n\n", "\n", ". ", " ", ""],
)


markdown_splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=[
        ("##", "Header 2"),
        ("###", "Header 3"),
    ],
    strip_headers=False,
)
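A brief sketch of how these two splitters are typically chained: split on markdown headers first, then enforce the chunk-size budget recursively. The sample text is illustrative.

from core.text_processors import markdown_splitter, recursive_splitter

sample = "## Treatment\nPreferred regimens are ETV, TDF, and TAF.\n### Monitoring\nCheck HBV DNA and ALT."
header_docs = markdown_splitter.split_text(sample)        # one Document per header section
chunks = recursive_splitter.split_documents(header_docs)  # respects chunk_size=3500 / overlap=400
print(len(chunks), chunks[0].metadata)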
core/tools.py
ADDED
@@ -0,0 +1,296 @@
import os
import re
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Optional, List

import pytz
from langchain.schema import Document
from langchain.tools import tool
from .retrievers import hybrid_search
from .context_enrichment import enrich_retrieved_documents
from .config import logger

# Canonical provider names - For HBV: SASLT only
CANONICAL_PROVIDERS = ["SASLT"]

# Global configuration for medical_guidelines_knowledge_tool retrieval and enrichment
TOOL_K_VECTOR = 5  # Number of documents to retrieve using vector search (per provider)
TOOL_K_BM25 = 2  # Number of documents to retrieve using BM25 search (per provider)
TOOL_PAGES_BEFORE = 1  # Number of pages to include before each top result
TOOL_PAGES_AFTER = 1  # Number of pages to include after each top result
TOOL_MAX_ENRICHED = 2  # Maximum number of top documents to enrich with context (per provider)

# Global variables to store context for validation
_last_question = None  # Stores the tool query
_last_documents = None

TOOL_MAX_WORKERS = max(2, min(8, (os.cpu_count() or 4)))
_tool_executor = ThreadPoolExecutor(max_workers=TOOL_MAX_WORKERS)


# Map lowercase variants and full names to canonical provider codes
_PROVIDER_ALIASES = {
    "saslt": "SASLT",
    "saslt 2021": "SASLT",
    "saudi association for the study of liver diseases and transplantation": "SASLT",
    "saslt guidelines": "SASLT",
}


def _normalize_provider(provider: Optional[str], query: str) -> Optional[str]:
    """Normalize provider name from explicit parameter or query text."""
    text = provider if provider else query
    if not text:
        return None

    t = text.lower()

    # Quick direct hits for canonical providers
    for canon in CANONICAL_PROVIDERS:
        if re.search(rf"\b{re.escape(canon.lower())}\b", t):
            return canon

    # Alias-based detection
    for alias, canon in _PROVIDER_ALIASES.items():
        if alias in t:
            return canon

    # If the explicit provider didn't match, try the query text as a fallback
    if provider and provider != query:
        return _normalize_provider(None, query)

    return None


def clear_text(text: str) -> str:
    """Clean and normalize text by removing markdown and excess whitespace."""
    if not text:
        return ""
    t = text
    # Normalize newlines
    t = t.replace("\r\n", "\n").replace("\r", "\n")
    # Links: keep title and URL
    t = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", r"\1 (\2)", t)
    # Images: drop entirely
    t = re.sub(r"!\[[^\]]*\]\([^)]*\)", "", t)
    # Remove headers/quotes markers at line starts
    t = re.sub(r"(?m)^[>\s]*#{1,6}\s*", "", t)
    # Remove backticks/code fences and emphasis
    t = t.replace("```", "").replace("`", "")
    t = t.replace("**", "").replace("*", "").replace("_", "")
    # Collapse spaces before newlines
    t = re.sub(r"[ \t]+\n", "\n", t)
    # Collapse multiple newlines and spaces
    t = re.sub(r"\n{3,}", "\n\n", t)
    t = re.sub(r"[ \t]{2,}", " ", t)
    # Trim
    t = t.strip()
    return t
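A quick illustration of `clear_text` on invented markdown input; the expected result in the comment follows from the substitutions above.

from core.tools import clear_text

md = "## Dosing\n\n**ETV** 0.5 mg daily, see [SASLT](https://example.org/guide).\n\n\n"
print(clear_text(md))
# -> "Dosing\n\nETV 0.5 mg daily, see SASLT (https://example.org/guide)."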
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _format_docs_with_citations(docs: List[Document], group_by_provider: bool = False) -> str:
|
| 93 |
+
"""Format documents with citations."""
|
| 94 |
+
if not docs:
|
| 95 |
+
return "No results."
|
| 96 |
+
|
| 97 |
+
if group_by_provider:
|
| 98 |
+
return _format_grouped_by_provider(docs)
|
| 99 |
+
|
| 100 |
+
parts = []
|
| 101 |
+
for i, d in enumerate(docs, start=1):
|
| 102 |
+
meta = d.metadata or {}
|
| 103 |
+
citation = _build_citation(i, meta, d.page_content)
|
| 104 |
+
parts.append(citation)
|
| 105 |
+
|
| 106 |
+
return "\n\n".join(parts)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _build_citation(index: int, metadata: dict, content: str, include_provider: bool = True) -> str:
|
| 110 |
+
"""Build a single citation string with clean formatting."""
|
| 111 |
+
source = metadata.get("source", "unknown")
|
| 112 |
+
page = metadata.get("page_number", "?")
|
| 113 |
+
provider = metadata.get("provider", "unknown")
|
| 114 |
+
disease = metadata.get("disease", "unknown")
|
| 115 |
+
is_context = metadata.get("context_enrichment", False)
|
| 116 |
+
|
| 117 |
+
snippet = clear_text(content)
|
| 118 |
+
|
| 119 |
+
# Build citation header
|
| 120 |
+
citation = f"📄 Result {index}:\n"
|
| 121 |
+
|
| 122 |
+
# Build metadata line
|
| 123 |
+
metadata_parts = []
|
| 124 |
+
if include_provider:
|
| 125 |
+
metadata_parts.append(f"Provider: {provider}")
|
| 126 |
+
metadata_parts.append(f"Disease: {disease}")
|
| 127 |
+
metadata_parts.append(f"Source: {source}")
|
| 128 |
+
metadata_parts.append(f"Page: {page}")
|
| 129 |
+
|
| 130 |
+
citation += " | ".join(metadata_parts)
|
| 131 |
+
|
| 132 |
+
if is_context:
|
| 133 |
+
citation += " [CONTEXT PAGE]"
|
| 134 |
+
|
| 135 |
+
citation += f"\n\n{snippet}\n"
|
| 136 |
+
return citation
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _document_to_dict(doc: Document) -> dict:
|
| 140 |
+
"""Convert a Document to a dictionary for storage."""
|
| 141 |
+
return {
|
| 142 |
+
"doc_id": getattr(doc, 'id', None),
|
| 143 |
+
"source": doc.metadata.get("source", "unknown"),
|
| 144 |
+
"provider": doc.metadata.get("provider", "unknown"),
|
| 145 |
+
"page_number": doc.metadata.get("page_number", "unknown"),
|
| 146 |
+
"disease": doc.metadata.get("disease", "unknown"),
|
| 147 |
+
"context_enrichment": doc.metadata.get("context_enrichment", False),
|
| 148 |
+
"enriched": doc.metadata.get("enriched", False),
|
| 149 |
+
"pages_included": doc.metadata.get("pages_included", []),
|
| 150 |
+
"primary_page": doc.metadata.get("primary_page"),
|
| 151 |
+
"context_pages_before": doc.metadata.get("context_pages_before"),
|
| 152 |
+
"context_pages_after": doc.metadata.get("context_pages_after"),
|
| 153 |
+
"content": doc.page_content
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _format_grouped_by_provider(docs: List[Document]) -> str:
    """Format results grouped by provider for multi-provider queries."""
    if not docs:
        return "No results found from any guideline provider."

    # Group documents by provider
    provider_groups = {}
    for doc in docs:
        provider = doc.metadata.get("provider", "unknown")
        if provider not in provider_groups:
            provider_groups[provider] = []
        provider_groups[provider].append(doc)

    # Format the header
    parts = [
        f"\n{'='*70}",
        "SEARCH RESULTS FROM SASLT 2021 GUIDELINES",
        f"Retrieved information from {len(provider_groups)} guideline provider(s)",
        f"{'='*70}\n"
    ]

    # Format each provider's results
    for idx, provider in enumerate(sorted(provider_groups.keys()), start=1):
        provider_docs = provider_groups[provider]

        # Provider header
        parts.append(f"\n{'─'*70}")
        parts.append(f"🏥 PROVIDER {idx}: {provider} ({len(provider_docs)} result{'s' if len(provider_docs) != 1 else ''})")
        parts.append(f"{'─'*70}\n")

        # Format each document for this provider
        for i, doc in enumerate(provider_docs, start=1):
            meta = doc.metadata or {}
            citation = _build_citation(i, meta, doc.page_content, include_provider=False)
            parts.append(citation)

            if i < len(provider_docs):
                parts.append("")

    return "\n".join(parts)
@tool
def medical_guidelines_knowledge_tool(query: str, provider: Optional[str] = None) -> str:
    """
    Retrieve comprehensive medical guideline knowledge with enriched context from the SASLT 2021 guidelines.
    Includes surrounding pages (before/after) for top results to provide complete clinical context.

    This retrieves information from the SASLT 2021 guidelines for HBV management.

    Returns detailed text with full metadata and contextual information for expert analysis.
    """
    global _last_question, _last_documents
    try:
        # Store the question for validation context
        _last_question = query

        # Normalize the provider name from either the explicit arg or the query text
        normalized_provider = _normalize_provider(provider, query)

        # Default to the SASLT provider when none is specified
        if not normalized_provider:
            logger.info("No specific provider - querying SASLT")
            normalized_provider = "SASLT"

        # Perform hybrid search
        docs = hybrid_search(query, normalized_provider, TOOL_K_VECTOR, TOOL_K_BM25)

        # Store the documents for validation context
        _last_documents = [_document_to_dict(doc) for doc in docs]

        return _format_docs_with_citations(docs)
    except Exception as e:
        logger.error(f"Retrieval error: {str(e)}")
        return f"Retrieval error: {str(e)}"


@tool
def get_current_datetime_tool() -> str:
    """
    Returns the current date, time, and day of the week for Egypt (Africa/Cairo).
    This is the only reliable source for date and time information. Use this tool
    whenever a user asks about 'today', 'now', or any other time-sensitive query.
    The output is always in English and in standard 12-hour format.
    """
    try:
        # Define the timezone for Egypt
        egypt_tz = pytz.timezone('Africa/Cairo')

        # Get the current time in that timezone
        now_egypt = datetime.now(egypt_tz)

        # Manual mapping to ensure English output regardless of system locale
        days_en = {
            0: "Monday", 1: "Tuesday", 2: "Wednesday", 3: "Thursday",
            4: "Friday", 5: "Saturday", 6: "Sunday"
        }
        months_en = {
            1: "January", 2: "February", 3: "March", 4: "April",
            5: "May", 6: "June", 7: "July", 8: "August",
            9: "September", 10: "October", 11: "November", 12: "December"
        }

        # Get English names using the manual mapping
        day_name = days_en[now_egypt.weekday()]
        month_name = months_en[now_egypt.month]
        day = now_egypt.day
        year = now_egypt.year

        # Format the time manually to avoid locale issues
        hour = now_egypt.hour
        minute = now_egypt.minute

        # Convert to 12-hour format
        if hour == 0:
            hour_12 = 12
            period = "AM"
        elif hour < 12:
            hour_12 = hour
            period = "AM"
        elif hour == 12:
            hour_12 = 12
            period = "PM"
        else:
            hour_12 = hour - 12
            period = "PM"

        time_str = f"{hour_12:02d}:{minute:02d} {period}"

        # Create the final string
        return f"Current date and time in Egypt: {day_name}, {month_name} {day}, {year} at {time_str}"

    except Exception as e:
        return f"Error getting current datetime: {str(e)}"
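A minimal usage sketch for these tools (the .invoke call is the standard LangChain tool interface; the query text is illustrative):

    # Tools decorated with @tool are invoked with a dict of their arguments
    result = medical_guidelines_knowledge_tool.invoke(
        {"query": "When is antiviral therapy indicated in chronic HBV?", "provider": "SASLT"}
    )
    print(result)  # citation-formatted text from _format_docs_with_citations

    print(get_current_datetime_tool.invoke({}))
    # e.g. "Current date and time in Egypt: Monday, January 1, 2024 at 09:30 AM"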
core/tracing.py
ADDED
@@ -0,0 +1,104 @@
import os
import uuid
import contextvars
from typing import Optional, Dict, Any

# LangSmith for Tracing and Monitoring
try:
    from langsmith import traceable, Client
    from langsmith import trace
except Exception:
    # Provide fallbacks if langsmith is not installed to avoid runtime crashes
    traceable = lambda *args, **kwargs: (lambda f: f)
    class _Noop:
        def __call__(self, *args, **kwargs):
            class _Ctx:
                def __enter__(self):
                    class _Run:
                        id = None
                    return _Run()
                def __exit__(self, exc_type, exc, tb):
                    return False
            return _Ctx()
    trace = _Noop()
    Client = None


# -----------------------------------------------------------------------------
# Environment Setup (non-destructive: set defaults if not present)
# -----------------------------------------------------------------------------
DEFAULT_LANGSMITH_API_KEY = "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"
DEFAULT_PROJECT = "medical_chatbot"

os.environ.setdefault("LANGSMITH_TRACING_V2", "true")
os.environ.setdefault("LANGSMITH_API_KEY", DEFAULT_LANGSMITH_API_KEY)
os.environ.setdefault("LANGCHAIN_PROJECT", DEFAULT_PROJECT)  # tracing
os.environ.setdefault("LANGSMITH_PROJECT", DEFAULT_PROJECT)  # feedback


# -----------------------------------------------------------------------------
# LangSmith Client
# -----------------------------------------------------------------------------
_langsmith_client: Optional[Client] = None
try:
    if Client is not None:
        _langsmith_client = Client(
            api_url="https://api.smith.langchain.com",
            api_key=os.environ.get("LANGSMITH_API_KEY"),
        )
except Exception:
    _langsmith_client = None


# -----------------------------------------------------------------------------
# Conversation Tracker for session metadata
# -----------------------------------------------------------------------------
class ConversationTracker:
    def __init__(self) -> None:
        self.session_id = str(uuid.uuid4())
        self.conversation_count = 0

    def start_new_session(self) -> None:
        self.session_id = str(uuid.uuid4())
        self.conversation_count = 0

    def get_session_metadata(self, increment: bool = False) -> Dict[str, Any]:
        if increment:
            self.conversation_count += 1
        return {
            "session_id": self.session_id,
            "conversation_count": self.conversation_count,
            "application": os.environ.get("LANGSMITH_PROJECT", DEFAULT_PROJECT),
            "version": "1.0",
        }


conversation_tracker = ConversationTracker()


# -----------------------------------------------------------------------------
# Helper function to safely log feedback to LangSmith
# -----------------------------------------------------------------------------
def log_to_langsmith(key: str, value: dict, run_id: Optional[str] = None) -> None:
    client = _langsmith_client
    if not client:
        return
    try:
        client.create_feedback(
            run_id=run_id,
            key=key,
            value=value,
            project=os.environ.get("LANGSMITH_PROJECT", DEFAULT_PROJECT),
        )
    except Exception:
        # Swallow logging errors to avoid breaking the app
        pass


__all__ = [
    "traceable",
    "trace",
    "Client",
    "conversation_tracker",
    "log_to_langsmith",
]
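A minimal sketch of how these helpers are meant to compose (the decorated function and the feedback payload here are hypothetical, not from this commit):

    from core.tracing import traceable, conversation_tracker, log_to_langsmith

    @traceable(name="answer_question")  # falls back to a no-op decorator when langsmith is absent
    def answer_question(question: str) -> str:
        return "..."  # hypothetical agent call

    meta = conversation_tracker.get_session_metadata(increment=True)
    answer_question("What is the HBV treatment threshold?")
    # run_id would normally come from the active LangSmith run; log_to_langsmith
    # swallows any client error, so this call never raises
    log_to_langsmith("user_feedback", {"rating": "up", **meta})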
core/utils.py
ADDED
@@ -0,0 +1,370 @@
import pickle
import logging
import os
import shutil
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import List, Optional, Iterable
from langchain.schema import Document
from langchain_community.vectorstores import FAISS

from .config import get_embedding_model, VECTOR_STORE_DIR, CHUNKS_PATH, NEW_DATA, PROCESSED_DATA
from .text_processors import markdown_splitter, recursive_splitter
from . import data_loaders

logger = logging.getLogger(__name__)

MAX_WORKERS = max(2, min(8, (os.cpu_count() or 4)))


def load_vector_store() -> Optional[FAISS]:
    """Load the existing vector store with proper error handling.

    Only attempts to load if the required FAISS files are present.
    """
    try:
        store_dir = Path(VECTOR_STORE_DIR)
        index_file = store_dir / "index.faiss"
        meta_file = store_dir / "index.pkl"  # created by LangChain FAISS.save_local

        # If the directory exists but the index files are missing, do not attempt a load
        if not (index_file.exists() and meta_file.exists()):
            logger.info("Vector store not initialized yet; index files not found. Skipping load.")
            return None

        vector_store = FAISS.load_local(
            str(VECTOR_STORE_DIR),
            get_embedding_model(),
            allow_dangerous_deserialization=True,
        )
        logger.info("Successfully loaded existing vector store")
        return vector_store
    except Exception as e:
        logger.error(f"Failed to load vector store: {e}")
        return None


def load_chunks() -> Optional[List[Document]]:
    """Load pre-processed document chunks from cache with error handling."""
    try:
        if Path(CHUNKS_PATH).exists():
            with open(CHUNKS_PATH, 'rb') as f:
                chunks = pickle.load(f)
            logger.info(f"Successfully loaded {len(chunks)} chunks from cache")
            return chunks
        else:
            logger.info("No cached chunks found")
            return None
    except Exception as e:
        logger.error(f"Failed to load chunks: {e}")
        return None


def save_chunks(chunks: List[Document]) -> bool:
    """Save processed document chunks to the cache file.

    Args:
        chunks: List of document chunks to save

    Returns:
        True if successful, False otherwise
    """
    try:
        # Ensure the cache directory exists
        Path(CHUNKS_PATH).parent.mkdir(parents=True, exist_ok=True)

        with open(CHUNKS_PATH, 'wb') as f:
            pickle.dump(chunks, f)
        logger.info(f"Successfully saved {len(chunks)} chunks to {CHUNKS_PATH}")
        return True
    except Exception as e:
        logger.error(f"Failed to save chunks: {e}")
        return False


# ============================================================================
# DOCUMENT PROCESSING UTILITIES
# ============================================================================

def _iter_files(root: Path) -> Iterable[Path]:
    """Yield PDF and Markdown files under the given root directory recursively.

    Args:
        root: Root directory to search

    Yields:
        Path objects for PDF and Markdown files
    """
    if not root.exists():
        return
    for p in root.rglob('*'):
        if p.is_file() and p.suffix.lower() in {'.pdf', '.md'}:
            yield p


def create_documents() -> List[Document]:
    """Load documents from the NEW_DATA directory.

    Returns:
        List of loaded documents

    Note:
        Use create_documents_and_files() if you need both documents and file paths.
    """
    docs, _ = create_documents_and_files()
    return docs


def _load_documents_for_file(file_path: Path) -> List[Document]:
    """Load documents from a single file (PDF or Markdown).

    Args:
        file_path: Path to the file to load

    Returns:
        List of documents loaded from the file
    """
    try:
        if file_path.suffix.lower() == '.pdf':
            return data_loaders.load_pdf_documents(file_path)
        return data_loaders.load_markdown_documents(file_path)
    except Exception as e:
        logger.error(f"Failed to load {file_path}: {e}")
        return []


def create_documents_and_files() -> tuple[List[Document], List[Path]]:
    """Load documents from the NEW_DATA directory and return both documents and file paths.

    Returns:
        Tuple of (documents, file_paths) where:
        - documents: List of loaded Document objects
        - file_paths: List of Path objects for files that were loaded
    """
    documents: List[Document] = []
    files = list(_iter_files(NEW_DATA))
    if not files:
        logger.info(f"No new files found under {NEW_DATA}")
        return documents, []

    worker_count = min(MAX_WORKERS, len(files)) or 1
    with ThreadPoolExecutor(max_workers=worker_count) as executor:
        futures = {executor.submit(_load_documents_for_file, file_path): file_path for file_path in files}
        for future in as_completed(futures):
            documents.extend(future.result())
    logger.info(f"Loaded {len(documents)} documents from {NEW_DATA}")
    return documents, files


def _segment_document(doc: Document) -> List[Document]:
    """Segment a document using markdown headers if applicable.

    Args:
        doc: Document to segment

    Returns:
        List of segmented documents (or the original if not markdown)
    """
    source_name = str(doc.metadata.get("source", "")).lower()
    if source_name.endswith('.md'):
        try:
            md_sections = markdown_splitter.split_text(doc.page_content)
            return [
                Document(page_content=section.page_content, metadata={**doc.metadata, **section.metadata})
                for section in md_sections
            ]
        except Exception:
            return [doc]
    return [doc]


def _split_chunk(doc: Document) -> List[Document]:
    """Split a document into smaller chunks using the recursive splitter.

    Args:
        doc: Document to split

    Returns:
        List of document chunks
    """
    try:
        return recursive_splitter.split_documents([doc])
    except Exception as exc:
        logger.error(f"Failed to split document {doc.metadata.get('source', 'unknown')}: {exc}")
        return []


def split_documents(documents: List[Document]) -> List[Document]:
    """Split documents into smaller chunks for vector store indexing.

    Process:
    1. Segment markdown files by headers (if applicable)
    2. Split all documents into uniform chunks using the recursive splitter

    Args:
        documents: List of documents to split

    Returns:
        List of document chunks ready for indexing
    """
    if not documents:
        return []

    # First pass: optional markdown header segmentation for .md sources
    worker_count = min(MAX_WORKERS, len(documents)) or 1
    with ThreadPoolExecutor(max_workers=worker_count) as executor:
        segmented_lists = list(executor.map(_segment_document, documents))
    segmented: List[Document] = [seg for sublist in segmented_lists for seg in sublist]

    if not segmented:
        return []

    # Second pass: split into uniform chunks
    split_worker_count = min(MAX_WORKERS, len(segmented)) or 1
    with ThreadPoolExecutor(max_workers=split_worker_count) as executor:
        chunk_lists = list(executor.map(_split_chunk, segmented))

    chunks = [chunk for chunk_list in chunk_lists for chunk in chunk_list]
    logger.info(f"Split {len(segmented)} documents into {len(chunks)} chunks")
    return chunks
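A small sketch of the two-pass splitting above (the Document contents are illustrative):

    from langchain.schema import Document
    from core.utils import split_documents

    md_doc = Document(
        page_content="# Dosing\nTenofovir 300 mg once daily...",
        metadata={"source": "guideline.md"},  # .md suffix triggers header segmentation
    )
    pdf_doc = Document(page_content="Long extracted PDF text ...", metadata={"source": "guideline.pdf"})
    chunks = split_documents([md_doc, pdf_doc])
    # Markdown docs are segmented by headers first; everything is then
    # re-split into uniform chunks by recursive_splitter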
def create_vector_store(chunks: List[Document]) -> FAISS:
    """Create a new FAISS vector store from document chunks and persist it.

    Args:
        chunks: List of document chunks to index

    Returns:
        Created FAISS vector store

    Raises:
        ValueError: If the chunks list is empty
    """
    if not chunks:
        raise ValueError("Cannot create vector store from empty chunks")
    vector_store = FAISS.from_documents(chunks, get_embedding_model())
    vector_store.save_local(str(VECTOR_STORE_DIR))
    logger.info("Vector store created and saved")
    return vector_store


def update_vector_store_with_chunks(chunks: List[Document]) -> FAISS:
    """Update the vector store with new chunks, or create it if it doesn't exist.

    Args:
        chunks: List of new document chunks to add

    Returns:
        Updated or newly created FAISS vector store

    Raises:
        ValueError: If chunks is empty and no existing store can be loaded
    """
    if not chunks:
        existing = load_vector_store()
        if existing:
            return existing
        # Nothing to add and nothing to return; fail explicitly rather than
        # falling through to create_vector_store([]) below
        raise ValueError("No chunks provided and no existing vector store found")

    store = load_vector_store()
    if store is None:
        store = create_vector_store(chunks)
    else:
        # Add to the existing store and persist
        store.add_documents(chunks)
        store.save_local(str(VECTOR_STORE_DIR))
        logger.info(f"Added {len(chunks)} new chunks to existing vector store")
    return store


def _move_to_processed(paths: List[Path]) -> None:
    """Move processed files to the processed_data folder, maintaining directory structure.

    Args:
        paths: List of file paths to move
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    for p in paths:
        try:
            if p.exists() and p.is_file():
                # Calculate the path relative to NEW_DATA
                try:
                    rel_path = p.relative_to(NEW_DATA)
                except ValueError:
                    # File is not under NEW_DATA, skip it
                    logger.warning(f"File {p} is not under NEW_DATA directory, skipping")
                    continue

                # Create the destination path in PROCESSED_DATA with the same structure
                dest_dir = PROCESSED_DATA / rel_path.parent
                dest_dir.mkdir(parents=True, exist_ok=True)

                # Add a timestamp to the filename to avoid overwriting
                dest_file = dest_dir / f"{p.stem}_{timestamp}{p.suffix}"

                # Move the file
                shutil.move(str(p), str(dest_file))
                logger.info(f"📦 Moved processed file: {p.name} -> {dest_file.relative_to(PROCESSED_DATA)}")
        except Exception as e:
            logger.error(f"❌ Failed to move {p}: {e}")


def _cleanup_empty_dirs(root: Path) -> None:
    """Remove empty directories under the root directory (best-effort).

    Args:
        root: Root directory to clean up
    """
    try:
        # Walk bottom-up (deepest paths first) to remove empty directories
        dirs = [d for d in root.rglob('*') if d.is_dir()]
        for dirpath in sorted(dirs, key=lambda x: len(str(x)), reverse=True):
            try:
                if not any(dirpath.iterdir()):
                    dirpath.rmdir()
                    logger.info(f"Removed empty directory: {dirpath}")
            except Exception:
                pass
    except Exception:
        pass


def process_new_data_and_update_vector_store() -> Optional[FAISS]:
    """Process new documents and update the vector store.

    Workflow:
    1. Load documents from the NEW_DATA directory
    2. Split documents into chunks
    3. Update the chunks cache and the vector store
    4. Move processed files to processed_data and clean up empty directories

    Returns:
        Updated FAISS vector store, or None if processing failed
    """
    try:
        docs, files = create_documents_and_files()
        if not docs:
            logger.info("No new documents to process.")
            return load_vector_store()

        chunks = split_documents(docs)

        # Save/merge chunks first (durability)
        existing_chunks = load_chunks() or []
        merged_chunks = existing_chunks + chunks

        with ThreadPoolExecutor(max_workers=2) as executor:
            save_future = executor.submit(save_chunks, merged_chunks)
            store_future = executor.submit(update_vector_store_with_chunks, chunks)
            save_success = save_future.result()
            store = store_future.result()

        if not save_success:
            logger.warning("Chunk persistence reported failure; vector store was updated but cache may be stale.")

        # If we reached here, the store update succeeded; move processed source files
        _move_to_processed(files)
        _cleanup_empty_dirs(NEW_DATA)

        logger.info(
            f"✅ Processed {len(docs)} new documents into {len(chunks)} chunks, updated vector store, and moved files to processed_data."
        )
        return store
    except Exception as e:
        logger.error(f"Failed processing new data: {e}")
        return None
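End to end, the intended ingestion flow looks roughly like this (a sketch; NEW_DATA and VECTOR_STORE_DIR come from core/config.py, and the query string is illustrative):

    from core.utils import process_new_data_and_update_vector_store, load_vector_store

    # Drop new PDF/Markdown guidelines under NEW_DATA, then:
    store = process_new_data_and_update_vector_store()
    if store is None:
        store = load_vector_store()  # fall back to the previously saved index

    if store is not None:
        for doc in store.similarity_search("HBeAg-positive chronic hepatitis B treatment", k=3):
            print(doc.metadata.get("source"), doc.metadata.get("page_number"))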
core/validation.py
ADDED
@@ -0,0 +1,639 @@
import json
import os
import uuid
import traceback
from datetime import datetime
from typing import Dict, List, Any, Optional
import pytz
from langchain_openai import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage

from .config import logger
from .github_storage import get_github_storage


class MedicalAnswerValidator:
    """
    Medical answer validation system that evaluates responses using a separate LLM instance.
    Produces structured JSON evaluations and saves them to evaluation_results.json.
    """

    def __init__(self):
        """Initialize the validator with the LLM and system prompt."""
        self.validator_llm = self._create_validator_llm()
        self.validation_system_prompt = self._create_validation_system_prompt()
        self.evaluation_file = "evaluation_results.json"
        logger.info("Medical answer validator initialized successfully")

    def _get_next_interaction_id(self) -> str:
        """Get the next interaction ID by finding the highest existing ID and adding 1."""
        try:
            # Try to get the file from GitHub first
            github_storage = get_github_storage()
            existing_content = github_storage._get_file_content("medical_data/evaluation_results.json")

            if existing_content:
                try:
                    evaluations = json.loads(existing_content)
                    if evaluations and isinstance(evaluations, list):
                        logger.info(f"Found {len(evaluations)} existing evaluations in GitHub")
                        # Find the highest existing ID
                        max_id = 0
                        for eval_item in evaluations:
                            try:
                                current_id = int(eval_item.get("interaction_id", "0"))
                                max_id = max(max_id, current_id)
                            except (ValueError, TypeError):
                                continue
                        next_id = str(max_id + 1)
                        logger.info(f"Next interaction ID will be: {next_id}")
                        return next_id
                except json.JSONDecodeError as e:
                    logger.warning(f"Failed to parse GitHub evaluation file: {e}")

            # Fallback to the local file
            if os.path.exists(self.evaluation_file):
                logger.info("GitHub file not found, checking local file")
                with open(self.evaluation_file, "r", encoding="utf-8") as f:
                    evaluations = json.load(f)

                if evaluations:
                    logger.info(f"Found {len(evaluations)} existing evaluations in local file")
                    # Find the highest existing ID
                    max_id = 0
                    for eval_item in evaluations:
                        try:
                            current_id = int(eval_item.get("interaction_id", "0"))
                            max_id = max(max_id, current_id)
                        except (ValueError, TypeError):
                            continue
                    next_id = str(max_id + 1)
                    logger.info(f"Next interaction ID from local file: {next_id}")
                    return next_id
                else:
                    logger.info("Local file is empty, starting with ID 1")
                    return "1"
            else:
                logger.info("No existing evaluation file found, starting with ID 1")
                return "1"
        except Exception as e:
            logger.error(f"Error getting next interaction ID: {e}")
            return "1"

    def _clean_documents_for_storage(self, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Clean documents by removing snippets and keeping only essential fields."""
        cleaned_docs = []
        for doc in documents:
            is_context_page = doc.get("context_enrichment", False)

            cleaned_doc = {
                "doc_id": doc.get("doc_id"),
                "source": doc.get("source", "unknown"),
                "provider": doc.get("provider", "unknown"),
                "page_number": doc.get("page_number", "unknown"),
                "disease": doc.get("disease", "unknown"),
                "page_type": "CONTEXT PAGE" if is_context_page else "ORIGINAL PAGE",
                "context_enrichment": is_context_page,
                "content": doc.get("content", "")
            }
            cleaned_docs.append(cleaned_doc)
        return cleaned_docs

    def _create_validation_system_prompt(self) -> str:
        """Create the system prompt for the validation LLM."""
        return """Role

You are a medical information validator tasked with validating the following answer to ensure it is accurate, complete, relevant, well-structured (coherent), appropriately concise (length), and properly attributed (cited) based only on the provided documents.

Here is your input:
Question: [User's original question]

Retrieved Answer: [The answer generated or retrieved from documents]

Documents: [Provide a link or summary of the relevant document sections]

Validation Task Criteria:

For each criterion below, provide a Score (0-100%) and a detailed Comment explaining the score and noting any necessary improvements, specific issues, or confirming satisfactory performance.

Accuracy (0-100%) Is the answer factually correct based only on the provided documents? Ensure that no information contradicts what is written in the documents.

If you find any discrepancies or factual errors, point them out in the [Accuracy_Comment].

If the answer contains unsupported statements (hallucinations), highlight them in the [Accuracy_Comment].

Validation Score Guidelines:

100%: The answer is factually correct, with no contradictions or missing information based on the provided documents.

85-99%: The answer is mostly correct, but contains minor inaccuracies or omissions that don't substantially affect the overall accuracy.

70-84%: The answer contains notable factual errors or omissions that may affect the response's reliability.

Below 70%: The answer is factually incorrect, contains critical errors, or misrepresents the content of the documents.

Coherence (0-100%) Is the answer logically structured and clear? Ensure the answer flows well, uses appropriate language, and makes sense to a human reader.

If the answer is unclear or poorly structured, suggest specific improvements in the [Coherence_Comment].

Coherence Score Guidelines:

100%: The answer is logically structured, easy to understand, and free from confusion or ambiguity.

85-99%: The answer is mostly clear but may have slight issues with flow or readability, such as minor disjointedness.

70-84%: The answer lacks clarity or contains some sections that confuse the reader due to poor structure.

Below 70%: The answer is poorly structured or difficult to follow, requiring significant improvement in clarity and flow.

Relevance (0-100%) Does the answer address the user's question adequately and fully? Ensure that the core topic of the question is covered and that no irrelevant or off-topic information is included.

If parts of the question are missed or the answer is irrelevant, identify which parts need improvement in the [Relevance_Comment].

Relevance Score Guidelines:

100%: The answer directly addresses all parts of the user's question without unnecessary deviations.

85-99%: The answer is mostly relevant, but might include slight off-topic information or miss minor aspects of the question.

70-84%: The answer misses key points or includes significant irrelevant details that distract from the question.

Below 70%: The answer is largely irrelevant to the user's question or includes significant off-topic information.

Completeness (0-100%) Does the answer provide all necessary information that is available in the documents to fully address the question? Are there any critical details missing?

If the answer is incomplete or vague, suggest what additional details should be included from the documents in the [Completeness_Comment].

Completeness Score Guidelines:

100%: The answer provides all necessary information in sufficient detail, covering all aspects of the question based on the documents.

85-99%: The answer covers most of the required details but may lack some minor points available in the source.

70-84%: The answer is missing critical information available in the documents or lacks important details to fully address the question.

Below 70%: The answer is severely incomplete, leaving out essential information available in the documents.

Citations/Attribution (0-100%) Is every claim in the answer correctly attributed (cited) to the relevant document(s)? Are all citations accurate and correctly placed?

If any statement lacks a citation or has an incorrect citation, note the specific issue in the [Citations_Attribution_Comment].

Citations/Attribution Score Guidelines:

100%: Every piece of information is correctly and appropriately cited to the supporting document(s).

85-99%: Citations are mostly correct, but there are one or two minor errors (e.g., misplaced citation, minor formatting issue).

70-84%: Several statements are missing citations, or multiple citations are incorrectly attributed, leading to potential confusion about the source.

Below 70%: The majority of the answer lacks proper citation, or citations are so poorly done they are unreliable.

Length (0-100%) Is the answer the right length to fully answer the question, without being too short (lacking detail) or too long (causing distraction or including irrelevant information)?

Provide a rating based on whether the answer strikes the right balance in the [Length_Comment].

Length Score Guidelines:

100%: The answer is appropriately detailed, offering enough information to fully address the question without unnecessary elaboration.

85-99%: The answer is sufficiently detailed but could be slightly more concise or might include minor irrelevant information.

70-84%: The answer is either too brief and lacks necessary detail or too lengthy with excessive, distracting information.

Below 70%: The answer is either too short to be meaningful or too long, causing distractions or loss of focus.

Final Evaluation Output

Based on the above checks, provide a rating and a comment for each aspect, and a final overall rating. Your entire output must be a single JSON object that strictly follows the structure defined below.

CRITICAL INSTRUCTIONS:
- Output ONLY valid JSON - no additional text before or after
- Use double quotes for all strings
- Ensure all rating values are whole numbers between 0-100, given as strings exactly as in the structure below
- Do not include any markdown formatting or code blocks
- Start your response immediately with { and end with }

Required JSON Output Structure:

{
  "Accuracy_Rating": "95",
  "Accuracy_Comment": "Detailed comment on factual correctness/issues",
  "Coherence_Rating": "90",
  "Coherence_Comment": "Detailed comment on flow, structure, and clarity",
  "Relevance_Rating": "88",
  "Relevance_Comment": "Detailed comment on addressing the question fully/irrelevant info",
  "Completeness_Rating": "92",
  "Completeness_Comment": "Detailed comment on missing critical details available in the documents",
  "Citations_Attribution_Rating": "85",
  "Citations_Attribution_Comment": "Detailed comment on citation accuracy and completeness",
  "Length_Rating": "90",
  "Length_Comment": "Detailed comment on conciseness and appropriate detail",
  "Overall_Rating": "90",
  "Final_Summary_and_Improvement_Plan": "Overall judgment. If rating is below 90%, describe what specific changes are needed to achieve a 100%. If 90% or above, state that the answer is ready."
}

REMEMBER: Output ONLY the JSON object above with your specific ratings and comments. No other text."""

    def _create_validator_llm(self) -> ChatOpenAI:
        """Create a separate LLM instance for validation."""
        try:
            openai_key = os.getenv("OPENAI_API_KEY")
            if not openai_key:
                raise ValueError("OpenAI API key is required for validation")
            return ChatOpenAI(
                model="gpt-4o",
                api_key=openai_key,
                # base_url=os.getenv("OPENAI_BASE_URL"),
                temperature=0.0,
                max_tokens=1024,
                request_timeout=60,
                max_retries=3,
                streaming=False,
            )
        except Exception as e:
            logger.error(f"Failed to create validator LLM: {e}")
            raise

    def validate_answer(
        self,
        question: str,
        retrieved_documents: List[Dict[str, Any]],
        generated_answer: str
    ) -> Dict[str, Any]:
        """
        Validate a medical answer and return a structured evaluation.

        Args:
            question: The original user question
            retrieved_documents: List of retrieved documents with metadata
            generated_answer: The AI-generated answer to validate

        Returns:
            Dict containing the complete evaluation with metadata
        """
        try:
            # Generate a simple sequential interaction ID
            interaction_id = self._get_next_interaction_id()

            logger.info(f"Starting validation for interaction {interaction_id}")

            # Clean documents (remove snippets) for storage
            cleaned_documents = self._clean_documents_for_storage(retrieved_documents)

            # Format documents for validation
            formatted_docs = self._format_documents_for_validation(retrieved_documents)

            # Create the validation prompt
            validation_prompt = f"""Question: {question}

Retrieved Answer: {generated_answer}

Documents: {formatted_docs}"""

            # Get validation from the LLM with retry logic
            validation_report = None
            max_retries = 3

            for attempt in range(max_retries):
                try:
                    messages = [
                        SystemMessage(content=self.validation_system_prompt),
                        HumanMessage(content=validation_prompt)
                    ]

                    response = self.validator_llm.invoke(messages)
                    validation_content = response.content.strip()

                    # Check if the response is empty
                    if not validation_content:
                        logger.warning(f"Empty response from validation LLM (attempt {attempt + 1})")
                        if attempt < max_retries - 1:
                            continue
                        else:
                            validation_report = self._create_fallback_validation("Empty response from validation LLM")
                            break

                    # Try to parse JSON directly first
                    try:
                        validation_report = json.loads(validation_content)
                    except json.JSONDecodeError:
                        # Try to extract JSON from a response that might have extra text
                        validation_report = self._extract_json_from_response(validation_content)
                        if validation_report is None:
                            raise json.JSONDecodeError("Could not extract valid JSON", validation_content, 0)

                    # Validate that all required fields are present
                    required_fields = [
                        "Accuracy_Rating", "Accuracy_Comment",
                        "Coherence_Rating", "Coherence_Comment",
                        "Relevance_Rating", "Relevance_Comment",
                        "Completeness_Rating", "Completeness_Comment",
                        "Citations_Attribution_Rating", "Citations_Attribution_Comment",
                        "Length_Rating", "Length_Comment",
                        "Overall_Rating", "Final_Summary_and_Improvement_Plan"
                    ]

                    missing_fields = [field for field in required_fields if field not in validation_report]
                    if missing_fields:
                        logger.warning(f"Missing fields in validation response: {missing_fields}")
                        if attempt < max_retries - 1:
                            continue
                        else:
                            # Fill the missing fields
                            for field in missing_fields:
                                if field.endswith("_Rating"):
                                    validation_report[field] = "0"
                                else:
                                    validation_report[field] = f"Field missing from validation response: {field}"

                    # Success - break out of the retry loop
                    break

                except json.JSONDecodeError as e:
                    logger.error(f"Failed to parse validation JSON (attempt {attempt + 1}): {e}")
                    logger.error(f"Raw response: {validation_content[:200]}...")

                    if attempt < max_retries - 1:
                        continue
                    else:
                        validation_report = self._create_fallback_validation(f"JSON parsing failed after {max_retries} attempts: {str(e)}")

                except Exception as e:
                    logger.error(f"Validation LLM error (attempt {attempt + 1}): {e}")

                    if attempt < max_retries - 1:
                        continue
                    else:
                        # Use basic validation as the final fallback
                        logger.info("Using basic heuristic validation as fallback")
                        validation_report = self._create_basic_validation(question, generated_answer, retrieved_documents)

            # Ensure we have a validation report
            if validation_report is None:
                logger.info("Creating basic validation as final fallback")
                validation_report = self._create_basic_validation(question, generated_answer, retrieved_documents)

            # Create the complete evaluation structure
            evaluation = {
                "interaction_id": interaction_id,
                "timestamp": datetime.now(pytz.timezone('Africa/Cairo')).isoformat(),
                "question": question,
                "retrieved_documents": cleaned_documents,
                "generated_answer": generated_answer,
                "validation_report": validation_report
            }

            # Save to the JSON file
            self._save_evaluation(evaluation)

            return evaluation

        except Exception as e:
            logger.error(f"Error during validation: {e}")
            return self._create_error_evaluation(question, retrieved_documents, generated_answer, str(e))

    def _format_documents_for_validation(self, documents: List[Dict[str, Any]]) -> str:
        """Format retrieved documents for the validation prompt."""
        if not documents:
            return "No documents provided."

        formatted_docs = []
        for i, doc in enumerate(documents, 1):
            doc_info = f"Document {i}:\n"
            doc_info += f"Source: {doc.get('source', 'Unknown')}\n"
            doc_info += f"Provider: {doc.get('provider', 'Unknown')}\n"
            doc_info += f"Page: {doc.get('page_number', 'Unknown')}\n"
            doc_info += f"Content: {doc.get('snippet', doc.get('content', 'No content'))}\n"
            formatted_docs.append(doc_info)

        return "\n\n".join(formatted_docs)

    def _create_fallback_validation(self, error_msg: str) -> Dict[str, str]:
        """Create a fallback validation report when JSON parsing fails."""
        return {
            "Accuracy_Rating": "0",
            "Accuracy_Comment": f"Validation failed due to parsing error: {error_msg}",
            "Coherence_Rating": "0",
            "Coherence_Comment": "Unable to evaluate due to validation system error",
            "Relevance_Rating": "0",
            "Relevance_Comment": "Unable to evaluate due to validation system error",
            "Completeness_Rating": "0",
            "Completeness_Comment": "Unable to evaluate due to validation system error",
            "Citations_Attribution_Rating": "0",
            "Citations_Attribution_Comment": "Unable to evaluate due to validation system error",
            "Length_Rating": "0",
            "Length_Comment": "Unable to evaluate due to validation system error",
            "Overall_Rating": "0",
            "Final_Summary_and_Improvement_Plan": f"Validation system encountered an error: {error_msg}"
        }

    def _extract_json_from_response(self, response_text: str) -> Optional[Dict[str, str]]:
        """Extract JSON from a response that might contain extra text."""
        try:
            # Try to find a JSON object in the response
            start_idx = response_text.find('{')
            end_idx = response_text.rfind('}')

            if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
                json_text = response_text[start_idx:end_idx + 1]
                return json.loads(json_text)
            else:
                raise ValueError("No JSON object found in response")

        except Exception as e:
            logger.error(f"Failed to extract JSON from response: {e}")
            return None

    def _create_basic_validation(self, question: str, answer: str, documents: List[Dict[str, Any]]) -> Dict[str, str]:
        """Create a basic validation when the LLM fails but we can still provide some assessment."""

        # Basic heuristic scoring
        accuracy_score = "75"  # Assume reasonable accuracy if documents are provided
        coherence_score = "80" if len(answer) > 100 and "." in answer else "60"
        relevance_score = "70" if any(word in answer.lower() for word in question.lower().split()) else "50"
        completeness_score = "70" if len(answer) > 200 else "50"
        citations_score = "80" if "Source:" in answer else "30"
        length_score = "75" if 100 < len(answer) < 2000 else "60"

        # Calculate the overall score as the average
        scores = [int(accuracy_score), int(coherence_score), int(relevance_score),
                  int(completeness_score), int(citations_score), int(length_score)]
        overall_score = str(sum(scores) // len(scores))

        return {
            "Accuracy_Rating": accuracy_score,
            "Accuracy_Comment": "Basic heuristic assessment - LLM validation unavailable. Answer appears to reference provided documents.",
            "Coherence_Rating": coherence_score,
            "Coherence_Comment": "Basic heuristic assessment - Answer structure and length suggest reasonable coherence.",
            "Relevance_Rating": relevance_score,
            "Relevance_Comment": "Basic heuristic assessment - Answer appears to address key terms from the question.",
            "Completeness_Rating": completeness_score,
            "Completeness_Comment": "Basic heuristic assessment - Answer length suggests reasonable completeness.",
            "Citations_Attribution_Rating": citations_score,
            "Citations_Attribution_Comment": "Basic heuristic assessment - Citations detected in answer format." if "Source:" in answer else "Basic heuristic assessment - Limited citation formatting detected.",
            "Length_Rating": length_score,
            "Length_Comment": "Basic heuristic assessment - Answer length appears appropriate for a medical question.",
            "Overall_Rating": overall_score,
            "Final_Summary_and_Improvement_Plan": f"Basic validation completed (Overall: {overall_score}/100). LLM-based validation was unavailable, so heuristic scoring was used. For full validation, ensure the validation LLM service is accessible."
        }

    def _create_error_evaluation(
        self,
        question: str,
        documents: List[Dict[str, Any]],
        answer: str,
        error_msg: str
    ) -> Dict[str, Any]:
        """Create an error evaluation when validation completely fails."""
        return {
            "interaction_id": str(uuid.uuid4()),
            "timestamp": datetime.now(pytz.timezone('Africa/Cairo')).isoformat(),
            "question": question,
            "retrieved_documents": documents,
            "generated_answer": answer,
            "validation_report": {
                "Accuracy_Rating": "0",
                "Accuracy_Comment": f"Validation error: {error_msg}",
                "Coherence_Rating": "0",
                "Coherence_Comment": f"Validation error: {error_msg}",
                "Relevance_Rating": "0",
                "Relevance_Comment": f"Validation error: {error_msg}",
                "Completeness_Rating": "0",
                "Completeness_Comment": f"Validation error: {error_msg}",
                "Citations_Attribution_Rating": "0",
                "Citations_Attribution_Comment": f"Validation error: {error_msg}",
                "Length_Rating": "0",
                "Length_Comment": f"Validation error: {error_msg}",
                "Overall_Rating": "0",
                "Final_Summary_and_Improvement_Plan": f"System error prevented validation: {error_msg}"
            },
            "error": error_msg
        }

    def _save_evaluation(self, evaluation: Dict[str, Any]) -> None:
        """Save the evaluation to the GitHub repository."""
        try:
            logger.info(f"Attempting to save evaluation with ID: {evaluation['interaction_id']}")

            # Try to save to GitHub first
            github_storage = get_github_storage()
            logger.info("GitHub storage instance obtained, calling save_validation_results...")
            success = github_storage.save_validation_results(evaluation)

            if success:
                logger.info(f"✓ Evaluation saved to GitHub successfully with ID: {evaluation['interaction_id']}")
            else:
                logger.warning(f"GitHub save failed for evaluation {evaluation['interaction_id']}, falling back to local storage")
                # Fallback to local storage if GitHub fails
                evaluations = []
                if os.path.exists(self.evaluation_file):
                    try:
                        with open(self.evaluation_file, 'r', encoding='utf-8') as f:
                            evaluations = json.load(f)
                        logger.info(f"Loaded {len(evaluations)} existing evaluations from local file")
                    except (json.JSONDecodeError, FileNotFoundError) as e:
                        logger.warning(f"Could not load local file: {e}")
                        evaluations = []

                # Add the new evaluation
                evaluations.append(evaluation)

                # Save back to the local file
                with open(self.evaluation_file, 'w', encoding='utf-8') as f:
                    json.dump(evaluations, f, indent=2, ensure_ascii=False)

                logger.info(f"✓ Evaluation saved locally (GitHub failed) with ID: {evaluation['interaction_id']}")

        except Exception as e:
            logger.error(f"Failed to save evaluation: {e}")
            logger.error(f"Traceback: {traceback.format_exc()}")

    def get_evaluation_summary(self, limit: int = 10) -> Dict[str, Any]:
        """Get a summary of recent evaluations from the GitHub repository."""
        try:
            # Try to get the data from GitHub first
            github_storage = get_github_storage()
            github_results = github_storage.get_validation_results(limit)

            if github_results and "error" not in github_results:
                return github_results

            # Fallback to the local file if GitHub fails
            if not os.path.exists(self.evaluation_file):
                return {"message": "No evaluations found", "evaluations": []}

            with open(self.evaluation_file, 'r', encoding='utf-8') as f:
                evaluations = json.load(f)

            # Get the recent evaluations
            recent_evaluations = evaluations[-limit:] if evaluations else []

            # Calculate average scores
            if recent_evaluations:
                total_scores = {
                    "accuracy": 0,
                    "coherence": 0,
                    "relevance": 0,
                    "completeness": 0,
                    "citations": 0,
                    "length": 0,
                    "overall": 0
|
| 582 |
+
}
|
| 583 |
+
|
| 584 |
+
count = len(recent_evaluations)
|
| 585 |
+
for eval_data in recent_evaluations:
|
| 586 |
+
report = eval_data.get("validation_report", {})
|
| 587 |
+
total_scores["accuracy"] += int(report.get("Accuracy_Rating", 0))
|
| 588 |
+
total_scores["coherence"] += int(report.get("Coherence_Rating", 0))
|
| 589 |
+
total_scores["relevance"] += int(report.get("Relevance_Rating", 0))
|
| 590 |
+
total_scores["completeness"] += int(report.get("Completeness_Rating", 0))
|
| 591 |
+
total_scores["citations"] += int(report.get("Citations_Attribution_Rating", 0))
|
| 592 |
+
total_scores["length"] += int(report.get("Length_Rating", 0))
|
| 593 |
+
total_scores["overall"] += int(report.get("Overall_Rating", 0))
|
| 594 |
+
|
| 595 |
+
averages = {key: round(value / count, 1) for key, value in total_scores.items()}
|
| 596 |
+
else:
|
| 597 |
+
averages = {}
|
| 598 |
+
|
| 599 |
+
return {
|
| 600 |
+
"total_evaluations": len(evaluations),
|
| 601 |
+
"recent_count": len(recent_evaluations),
|
| 602 |
+
"average_scores": averages,
|
| 603 |
+
"evaluations": recent_evaluations
|
| 604 |
+
}
|
| 605 |
+
|
| 606 |
+
except Exception as e:
|
| 607 |
+
logger.error(f"Failed to get evaluation summary: {e}")
|
| 608 |
+
return {"error": str(e), "evaluations": []}
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# Global validator instance
|
| 612 |
+
_validator = None
|
| 613 |
+
|
| 614 |
+
def get_validator() -> MedicalAnswerValidator:
|
| 615 |
+
"""Get the global validator instance with lazy loading."""
|
| 616 |
+
global _validator
|
| 617 |
+
if _validator is None:
|
| 618 |
+
_validator = MedicalAnswerValidator()
|
| 619 |
+
return _validator
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
def validate_medical_answer(
|
| 623 |
+
question: str,
|
| 624 |
+
retrieved_documents: List[Dict[str, Any]],
|
| 625 |
+
generated_answer: str
|
| 626 |
+
) -> Dict[str, Any]:
|
| 627 |
+
"""
|
| 628 |
+
Convenience function to validate a medical answer.
|
| 629 |
+
|
| 630 |
+
Args:
|
| 631 |
+
question: The original user question
|
| 632 |
+
retrieved_documents: List of retrieved documents with metadata
|
| 633 |
+
generated_answer: The AI-generated answer to validate
|
| 634 |
+
|
| 635 |
+
Returns:
|
| 636 |
+
Dict containing the complete evaluation with metadata
|
| 637 |
+
"""
|
| 638 |
+
validator = get_validator()
|
| 639 |
+
return validator.validate_answer(question, retrieved_documents, generated_answer)
|
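The `validate_medical_answer` helper above is the module's public entry point: it lazily constructs the singleton validator and returns the full evaluation dict (including the `validation_report` shown in the error path). A minimal usage sketch follows; the `core.validation` import path matches this repository, but the document dict shape (`content`/`source` keys) is an illustrative assumption, since the real shape comes from the retriever:

from core.validation import validate_medical_answer

# Hypothetical retrieved-document shape; match whatever the retriever actually returns.
docs = [{"content": "Treat CHB when HBV DNA > 2,000 IU/mL and ALT > ULN ...",
         "source": "SASLT 2021, p. 12"}]

result = validate_medical_answer(
    question="When is antiviral therapy indicated in chronic hepatitis B?",
    retrieved_documents=docs,
    generated_answer="Therapy is indicated when HBV DNA > 2,000 IU/mL and ALT > ULN (Source: SASLT 2021).",
)
print(result["validation_report"]["Overall_Rating"])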
core/vector_store.py
ADDED
@@ -0,0 +1,34 @@
from pathlib import Path
from typing import Optional

from langchain_community.vectorstores import FAISS

from . import utils
# VECTOR_STORE_DIR and EMBEDDING_MODEL are assumed to live in core.config alongside logger.
from .config import logger, VECTOR_STORE_DIR, EMBEDDING_MODEL


def create_vector_store() -> FAISS:
    """Create a new vector store from documents in the NEW_DATA directory."""
    try:
        documents = utils.create_documents()
        chunks = utils.split_documents(documents)
        vector_store = utils.create_vector_store(chunks)
        logger.info("Vector store created successfully")
        return vector_store
    except Exception as e:
        logger.error(f"Error creating vector store: {str(e)}")
        raise


def load_vector_store() -> Optional[FAISS]:
    """Load the existing vector store, creating a new one if none exists."""
    try:
        if Path(VECTOR_STORE_DIR).exists():
            vector_store = FAISS.load_local(
                str(VECTOR_STORE_DIR),
                EMBEDDING_MODEL,
                allow_dangerous_deserialization=True
            )
            logger.info("Successfully loaded existing vector store")
            return vector_store
        else:
            logger.info("No existing vector store found")
            logger.info("Creating new vector store...")
            return create_vector_store()
    except Exception as e:
        logger.error(f"Failed to load vector store: {e}")
        return None
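Two notes on the module above. First, `FAISS.load_local` unpickles `index.pkl`, which is why `allow_dangerous_deserialization=True` must be passed explicitly; it is safe only for stores the application built itself. Second, a minimal caller sketch (the `core.vector_store` import path matches this repository; the query string is illustrative):

from core.vector_store import load_vector_store

store = load_vector_store()  # loads data/vector_store/, or builds it on first run
if store is not None:
    for doc in store.similarity_search("HBV treatment eligibility criteria", k=3):
        print(doc.metadata.get("source"), doc.page_content[:80])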
data/chunks.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f23c9d2504e67eeddfe16adbdd54d39cc89eba79731e9e9908ebd448ce565e9c
size 78626
data/vector_store/index.faiss
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:74dda4934f1b39740cf93e1737e56709468173e46923b0e8ef8e985ddf626e8f
size 233517
data/vector_store/index.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d650c0f8358ebccd9dbacd18c98bd0ca5794efbfbb3fec2f182b3e9eacb6b2de
size 82475
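The three `data/` entries above are Git LFS pointer files, not the binaries themselves: each records only the `oid` (SHA-256) and `size` of the real artifact. Cloning the Space with `git lfs install` followed by `git lfs pull` materializes the actual FAISS index and pickled chunks in place of the pointers.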
example_patient_input.json
ADDED
@@ -0,0 +1,171 @@
{
  "description": "Example HBV patient data for assessment API",
  "examples": [
    {
      "name": "Eligible Patient - Moderate Fibrosis",
      "data": {
        "sex": "Male",
        "age": 45,
        "pregnancy_status": "Not pregnant",
        "hbsag_status": "Positive",
        "duration_hbsag_months": 12,
        "hbv_dna_level": 50000,
        "hbeag_status": "Positive",
        "alt_level": 60,
        "fibrosis_stage": "F2-F3",
        "necroinflammatory_activity": "A2",
        "extrahepatic_manifestations": false,
        "immunosuppression_status": "None",
        "coinfections": [],
        "family_history_cirrhosis_hcc": false,
        "other_comorbidities": []
      },
      "expected_result": "Eligible - meets SASLT 2021 criteria (HBV DNA > 2,000, ALT > ULN, moderate fibrosis)"
    },
    {
      "name": "Eligible Patient - Cirrhosis",
      "data": {
        "sex": "Male",
        "age": 55,
        "pregnancy_status": "Not pregnant",
        "hbsag_status": "Positive",
        "duration_hbsag_months": 120,
        "hbv_dna_level": 500,
        "hbeag_status": "Negative",
        "alt_level": 30,
        "fibrosis_stage": "F4",
        "necroinflammatory_activity": "A1",
        "extrahepatic_manifestations": false,
        "immunosuppression_status": "None",
        "coinfections": [],
        "family_history_cirrhosis_hcc": true,
        "other_comorbidities": ["Diabetes"]
      },
      "expected_result": "Eligible - cirrhosis (F4) with detectable HBV DNA"
    },
    {
      "name": "Not Eligible Patient - Low HBV DNA",
      "data": {
        "sex": "Female",
        "age": 35,
        "pregnancy_status": "Not pregnant",
        "hbsag_status": "Positive",
        "duration_hbsag_months": 8,
        "hbv_dna_level": 1500,
        "hbeag_status": "Negative",
        "alt_level": 20,
        "fibrosis_stage": "F0-F1",
        "necroinflammatory_activity": "A0",
        "extrahepatic_manifestations": false,
        "immunosuppression_status": "None",
        "coinfections": [],
        "family_history_cirrhosis_hcc": false,
        "other_comorbidities": []
      },
      "expected_result": "Not Eligible - does not meet SASLT 2021 criteria"
    },
    {
      "name": "Pregnant Patient - High HBV DNA",
      "data": {
        "sex": "Female",
        "age": 28,
        "pregnancy_status": "Pregnant",
        "hbsag_status": "Positive",
        "duration_hbsag_months": 24,
        "hbv_dna_level": 150000,
        "hbeag_status": "Positive",
        "alt_level": 40,
        "fibrosis_stage": "F0-F1",
        "necroinflammatory_activity": "A1",
        "extrahepatic_manifestations": false,
        "immunosuppression_status": "None",
        "coinfections": [],
        "family_history_cirrhosis_hcc": false,
        "other_comorbidities": []
      },
      "expected_result": "Eligible - HBV DNA > 2,000, ALT > ULN. Note: Pregnant with HBV DNA > 100,000 requires prophylaxis"
    },
    {
      "name": "Patient with Extrahepatic Manifestations",
      "data": {
        "sex": "Male",
        "age": 42,
        "pregnancy_status": "Not pregnant",
        "hbsag_status": "Positive",
        "duration_hbsag_months": 36,
        "hbv_dna_level": 5000,
        "hbeag_status": "Negative",
        "alt_level": 28,
        "fibrosis_stage": "F0-F1",
        "necroinflammatory_activity": "A0",
        "extrahepatic_manifestations": true,
        "immunosuppression_status": "None",
        "coinfections": [],
        "family_history_cirrhosis_hcc": false,
        "other_comorbidities": ["Polyarteritis nodosa"]
      },
      "expected_result": "Eligible - extrahepatic manifestations (Grade D)"
    },
    {
      "name": "Patient with HIV Coinfection",
      "data": {
        "sex": "Male",
        "age": 38,
        "pregnancy_status": "Not pregnant",
        "hbsag_status": "Positive",
        "duration_hbsag_months": 60,
        "hbv_dna_level": 25000,
        "hbeag_status": "Positive",
        "alt_level": 80,
        "fibrosis_stage": "F2-F3",
        "necroinflammatory_activity": "A2",
        "extrahepatic_manifestations": false,
        "immunosuppression_status": "None",
        "coinfections": ["HIV"],
        "family_history_cirrhosis_hcc": false,
        "other_comorbidities": []
      },
      "expected_result": "Eligible - meets multiple criteria. Note: HIV coinfection requires specialized management"
    },
    {
      "name": "HBeAg-positive Chronic Infection (Age > 30)",
      "data": {
        "sex": "Female",
        "age": 35,
        "pregnancy_status": "Not pregnant",
        "hbsag_status": "Positive",
        "duration_hbsag_months": 180,
        "hbv_dna_level": 1000000,
        "hbeag_status": "Positive",
        "alt_level": 22,
        "fibrosis_stage": "F0-F1",
        "necroinflammatory_activity": "A0",
        "extrahepatic_manifestations": false,
        "immunosuppression_status": "None",
        "coinfections": [],
        "family_history_cirrhosis_hcc": false,
        "other_comorbidities": []
      },
      "expected_result": "Eligible - HBeAg-positive chronic infection, age > 30 (Grade D)"
    }
  ],
  "field_descriptions": {
    "sex": "Male or Female",
    "age": "Patient age in years (0-120)",
    "pregnancy_status": "Not pregnant or Pregnant",
    "hbsag_status": "Positive or Negative",
    "duration_hbsag_months": "Duration of HBsAg positivity in months (≥ 6 months required)",
    "hbv_dna_level": "HBV DNA level in IU/mL",
    "hbeag_status": "Positive or Negative",
    "alt_level": "ALT level in U/L (ULN: Men ≤ 35, Women ≤ 25)",
    "fibrosis_stage": "F0-F1 (minimal), F2-F3 (moderate), or F4 (cirrhosis)",
    "necroinflammatory_activity": "A0 (none), A1 (mild), A2 (moderate), or A3 (severe)",
    "extrahepatic_manifestations": "true or false",
    "immunosuppression_status": "None, Chemotherapy, or Other",
    "coinfections": "Array of: HIV, HCV, HDV",
    "family_history_cirrhosis_hcc": "true or false (first-degree relative)",
    "other_comorbidities": "Array of comorbidity names"
  }
}
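Each `data` object above is shaped as a ready-to-send request body for the assessment endpoint. A sketch of exercising the API with the first example, using `requests` (already pinned in requirements.txt); the base URL and route here are placeholders, so check api/routers/hbv_assessment.py and the auth middleware for the real path and any required headers:

import json
import requests

with open("example_patient_input.json", "r", encoding="utf-8") as f:
    examples = json.load(f)["examples"]

# Hypothetical route; confirm against api/routers/hbv_assessment.py.
resp = requests.post("http://localhost:8000/api/hbv/assess",
                     json=examples[0]["data"], timeout=60)
resp.raise_for_status()
print(resp.json())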
export_prompts.py
ADDED
@@ -0,0 +1,214 @@
"""
Script to export all prompts used in the HBV AI Assistant project to a Word document.
Includes: Agent System Prompt, Validation Prompt, and HBV Assessment Prompt.
"""
from docx import Document
from docx.shared import Pt, RGBColor, Inches
from docx.enum.text import WD_ALIGN_PARAGRAPH
import sys
import os
import re

# Add the project root to the path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))


def extract_system_message_from_agent():
    """Extract SYSTEM_MESSAGE from core/agent.py without importing it."""
    agent_path = os.path.join(os.path.dirname(__file__), 'core', 'agent.py')
    with open(agent_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # Extract SYSTEM_MESSAGE using a regex
    match = re.search(r'SYSTEM_MESSAGE = """(.*?)"""', content, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        raise ValueError("Could not extract SYSTEM_MESSAGE from core/agent.py")


def extract_validation_prompt():
    """Extract the validation prompt from core/validation.py without importing it."""
    validation_path = os.path.join(os.path.dirname(__file__), 'core', 'validation.py')
    with open(validation_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # Find the _create_validation_system_prompt method and extract the return string
    match = re.search(r'def _create_validation_system_prompt\(self\) -> str:.*?return """(.*?)"""', content, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        raise ValueError("Could not extract validation prompt from core/validation.py")


def extract_hbv_assessment_prompt():
    """Extract the HBV assessment prompt from core/hbv_assessment.py without importing it."""
    assessment_path = os.path.join(os.path.dirname(__file__), 'core', 'hbv_assessment.py')
    with open(assessment_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # Find the analysis_prompt in the assess_hbv_eligibility function
    match = re.search(r'analysis_prompt = f"""(.*?)"""', content, re.DOTALL)
    if match:
        # Extract the prompt template (without f-string variables)
        prompt_template = match.group(1).strip()
        return prompt_template
    else:
        raise ValueError("Could not extract HBV assessment prompt from core/hbv_assessment.py")


def create_prompts_document():
    """Create a Word document with all three prompts."""

    # Create a new Document
    doc = Document()

    # Add title
    title = doc.add_heading('HBV AI Assistant - System Prompts', 0)
    title.alignment = WD_ALIGN_PARAGRAPH.CENTER

    # Add metadata
    metadata = doc.add_paragraph()
    metadata.add_run('Project: HBV Clinical Decision Support System\n').bold = True
    metadata.add_run('Generated: October 30, 2025\n')
    metadata.add_run('Description: This document contains all system prompts used by the HBV AI Assistant, including the main agent prompt, validation prompt, and HBV assessment prompt.')

    doc.add_page_break()

    # ==================== HBV ASSESSMENT PROMPT ====================
    doc.add_heading('1. HBV Assessment System Prompt', 1)

    # Add description
    desc3 = doc.add_paragraph()
    desc3.add_run('Purpose: ').bold = True
    desc3.add_run('Evaluates patient eligibility for HBV treatment based on SASLT 2021 guidelines\n')
    desc3.add_run('Location: ').bold = True
    desc3.add_run('core/hbv_assessment.py\n')
    desc3.add_run('Model: ').bold = True
    desc3.add_run('GPT-4 (configurable)\n')
    desc3.add_run('Function: ').bold = True
    desc3.add_run('assess_hbv_eligibility()')

    doc.add_paragraph()  # Spacing

    # Add assessment overview
    doc.add_heading('Assessment Process:', 2)
    process = doc.add_paragraph()
    process.add_run('The HBV assessment prompt analyzes patient data against SASLT 2021 guidelines:\n')
    process_list = [
        'Evaluates patient parameters (HBV DNA, ALT, fibrosis stage, etc.)',
        'Compares against treatment eligibility criteria from SASLT 2021',
        'Determines eligibility status (Eligible/Not Eligible/Borderline)',
        'Recommends first-line antiviral agents (ETV, TDF, TAF) if eligible',
        'Provides comprehensive assessment with inline citations',
        'Includes special considerations (pregnancy, immunosuppression, coinfections)'
    ]
    for item in process_list:
        doc.add_paragraph(item, style='List Bullet')

    doc.add_paragraph()  # Spacing

    # Add the actual HBV assessment prompt
    doc.add_heading('Prompt Content:', 2)
    hbv_assessment_prompt = extract_hbv_assessment_prompt()

    assessment_para = doc.add_paragraph(hbv_assessment_prompt)
    assessment_para.style = 'Normal'

    # Format the assessment prompt text
    for run in assessment_para.runs:
        run.font.size = Pt(10)
        run.font.name = 'Courier New'

    doc.add_page_break()

    # ==================== AGENT PROMPT ====================
    doc.add_heading('2. Agent System Prompt', 1)

    # Add description
    desc1 = doc.add_paragraph()
    desc1.add_run('Purpose: ').bold = True
    desc1.add_run('Main conversational AI agent for clinical decision support\n')
    desc1.add_run('Location: ').bold = True
    desc1.add_run('core/agent.py\n')
    desc1.add_run('Model: ').bold = True
    desc1.add_run('GPT-4 (configurable)')

    doc.add_paragraph()  # Spacing

    # Add the actual prompt
    doc.add_heading('Prompt Content:', 2)
    system_message = extract_system_message_from_agent()
    prompt_para = doc.add_paragraph(system_message)
    prompt_para.style = 'Normal'

    # Format the prompt text
    for run in prompt_para.runs:
        run.font.size = Pt(10)
        run.font.name = 'Courier New'

    doc.add_page_break()

    # ==================== VALIDATION PROMPT ====================
    doc.add_heading('3. Validation System Prompt', 1)

    # Add description
    desc2 = doc.add_paragraph()
    desc2.add_run('Purpose: ').bold = True
    desc2.add_run('Validates generated medical answers for quality assurance\n')
    desc2.add_run('Location: ').bold = True
    desc2.add_run('core/validation.py\n')
    desc2.add_run('Model: ').bold = True
    desc2.add_run('GPT-4o')

    doc.add_paragraph()  # Spacing

    # Add validation criteria overview
    doc.add_heading('Validation Criteria:', 2)
    criteria = doc.add_paragraph()
    criteria.add_run('The validation prompt evaluates answers on 6 dimensions:\n')
    criteria_list = [
        'Accuracy (0-100%): Factual correctness based on provided documents',
        'Coherence (0-100%): Logical structure, clarity, and readability',
        'Relevance (0-100%): Addresses user\'s question without off-topic information',
        'Completeness (0-100%): Includes all necessary information from documents',
        'Citations/Attribution (0-100%): Proper citation of all claims',
        'Length (0-100%): Appropriate detail without being too brief or verbose'
    ]
    for criterion in criteria_list:
        doc.add_paragraph(criterion, style='List Bullet')

    doc.add_paragraph()  # Spacing

    # Add the actual validation prompt
    doc.add_heading('Prompt Content:', 2)
    validation_prompt = extract_validation_prompt()

    validation_para = doc.add_paragraph(validation_prompt)
    validation_para.style = 'Normal'

    # Format the validation prompt text
    for run in validation_para.runs:
        run.font.size = Pt(10)
        run.font.name = 'Courier New'

    # Save the document
    output_path = 'HBV_AI_Assistant_System_Prompts.docx'
    doc.save(output_path)
    print(f"✓ Document created successfully: {output_path}")
    print(f"✓ File location: {os.path.abspath(output_path)}")
    print("\n📋 Exported Prompts:")
    print("  1. HBV Assessment Prompt (core/hbv_assessment.py)")
    print("  2. Agent System Prompt (core/agent.py)")
    print("  3. Validation System Prompt (core/validation.py)")

    return output_path


if __name__ == "__main__":
    try:
        create_prompts_document()
    except Exception as e:
        print(f"✗ Error creating document: {str(e)}")
        import traceback
        traceback.print_exc()
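The extraction helpers above deliberately read the source files as text and pull prompts out with non-greedy regexes, so the script never imports the heavy agent modules. A toy demonstration of the pattern; note it would break if a prompt itself contained a triple-quoted string:

import re

# Stand-in for the contents of core/agent.py.
source = 'SYSTEM_MESSAGE = """You are a careful clinical assistant."""\n'

match = re.search(r'SYSTEM_MESSAGE = """(.*?)"""', source, re.DOTALL)
print(match.group(1))  # -> You are a careful clinical assistant.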
requirements.txt
ADDED
@@ -0,0 +1,41 @@
# API
fastapi==0.116.1
uvicorn==0.35.0

# LangChain ecosystem - only what's needed
langchain==0.3.27
langchain_community==0.3.27
langchain_openai==0.3.28
langchain_huggingface==0.3.1
langchain-pymupdf4llm==0.4.1

# OpenAI
openai==1.99.2

# Data processing
pandas==2.3.1
pydantic==2.11.7

# Utilities
python-dotenv==1.1.1
pytz==2025.2
requests==2.32.4

# Embeddings and search - CPU-optimized versions
sentence-transformers==5.0.0
faiss-cpu==1.11.0.post1  # CPU version is much smaller
rank-bm25==0.2.2

# Fix for numpy and torch compatibility
numpy<2
torch==2.2.2+cpu
--extra-index-url https://download.pytorch.org/whl/cpu

# Document generation
python-docx==1.1.2
reportlab==4.2.5

# Authentication
python-multipart>=0.0.18
itsdangerous==2.2.0
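pip reads option lines such as `--extra-index-url` from anywhere in a requirements file, so a plain `pip install -r requirements.txt` resolves the CPU-only `torch==2.2.2+cpu` wheel from the PyTorch index while the remaining pins come from PyPI.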
tempCodeRunnerFile.python
ADDED
@@ -0,0 +1,146 @@
# Import required libraries
import pandas as pd
from pathlib import Path
from typing import List
from langchain.schema import Document
from core.config import logger
from unstructured.partition.pdf import partition_pdf
from unstructured.chunking.title import chunk_by_title


def load_pdf_documents(pdf_path: Path) -> List[Document]:
    """
    Load and process PDF documents from medical guidelines using Unstructured.io.
    Uses the high-resolution strategy with ML-based table detection for borderless tables.
    Extracts disease and provider from the directory structure.

    Directory structure expected: data/new_data/PROVIDER/file.pdf
    Example: data/new_data/SASLT/SASLT_2021.pdf

    Args:
        pdf_path: Path to the PDF file

    Returns:
        List of Document objects with metadata (source, disease, provider, page_number)
    """
    try:
        # Validate that the file exists
        if not pdf_path.exists():
            raise FileNotFoundError(f"PDF file not found at {pdf_path}")

        # Extract provider from the directory structure
        # Structure: data/new_data/PROVIDER/file.pdf
        path_parts = pdf_path.parts
        disease = "HBV"  # Default disease for this system
        provider = "unknown"

        # Find the provider: it's the parent directory of the PDF file
        if len(path_parts) >= 2:
            provider = path_parts[-2]  # Parent directory (e.g., SASLT)

        # If the provider is 'new_data', the file sits directly in the new_data folder
        if provider.lower() == "new_data":
            provider = "unknown"

        # Use Unstructured.io to partition the PDF
        # The hi_res strategy uses ML models for better table detection
        elements = partition_pdf(
            filename=str(pdf_path),
            strategy="hi_res",            # Use ML-based detection for borderless tables
            infer_table_structure=True,   # Detect table structure without borders
            extract_images_in_pdf=True,   # Extract images with OCR
            languages=["eng"],            # OCR language
            include_page_breaks=True,     # Maintain page boundaries
        )

        # Group elements by page number
        pages_content = {}
        for element in elements:
            # Get the page number from metadata (1-indexed)
            page_num = element.metadata.page_number if hasattr(element.metadata, 'page_number') else 1

            if page_num not in pages_content:
                pages_content[page_num] = []

            # Convert the element to text
            pages_content[page_num].append(element.text)

        # Create Document objects for each page
        documents = []
        for page_num in sorted(pages_content.keys()):
            # Combine all elements on the page
            page_content = "\n\n".join(pages_content[page_num])

            if page_content.strip():
                processed_doc = Document(
                    page_content=page_content,
                    metadata={
                        "source": pdf_path.name,
                        "disease": disease,
                        "provider": provider,
                        "page_number": page_num
                    }
                )
                documents.append(processed_doc)

        logger.info(f"Loaded {len(documents)} pages from PDF: {pdf_path.name} (Disease: {disease}, Provider: {provider})")
        return documents

    except Exception as e:
        logger.error(f"Error loading PDF documents from {pdf_path}: {str(e)}")
        raise


# Alternative version: preserve element types (useful for RAG)
def load_pdf_documents_with_elements(pdf_path: Path) -> List[Document]:
    """
    Load PDF documents while preserving element types (text, table, title, etc.).
    Useful for better RAG retrieval by maintaining document structure.
    """
    try:
        if not pdf_path.exists():
            raise FileNotFoundError(f"PDF file not found at {pdf_path}")

        path_parts = pdf_path.parts
        disease = "HBV"
        provider = path_parts[-2] if len(path_parts) >= 2 else "unknown"
        if provider.lower() == "new_data":
            provider = "unknown"

        elements = partition_pdf(
            filename=str(pdf_path),
            strategy="hi_res",
            infer_table_structure=True,
            extract_images_in_pdf=True,
            languages=["eng"],
        )

        documents = []
        for idx, element in enumerate(elements):
            if element.text.strip():
                page_num = element.metadata.page_number if hasattr(element.metadata, 'page_number') else 1
                element_type = element.category  # e.g., "Table", "Title", "NarrativeText"

                processed_doc = Document(
                    page_content=element.text,
                    metadata={
                        "source": pdf_path.name,
                        "disease": disease,
                        "provider": provider,
                        "page_number": page_num,
                        "element_type": element_type,
                        "element_id": idx
                    }
                )
                documents.append(processed_doc)

        logger.info(f"Loaded {len(documents)} elements from PDF: {pdf_path.name} (Disease: {disease}, Provider: {provider})")
        return documents

    except Exception as e:
        logger.error(f"Error loading PDF documents from {pdf_path}: {str(e)}")
        raise


# Usage
doc = load_pdf_documents(Path(r"data\processed_data\SASLT\SASLT 2021_20251026_171017.pdf"))
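`chunk_by_title` is imported above but never called. A sketch of how it could slot in after `partition_pdf`, grouping elements into title-delimited chunks before they are wrapped in `Document` objects; the keyword arguments follow unstructured's chunking API and should be verified against the installed version (note that `unstructured` itself is not pinned in requirements.txt):

from pathlib import Path
from unstructured.partition.pdf import partition_pdf
from unstructured.chunking.title import chunk_by_title

# Example path taken from the docstring above.
elements = partition_pdf(filename=str(Path("data/new_data/SASLT/SASLT_2021.pdf")), strategy="hi_res")
chunks = chunk_by_title(elements, max_characters=1500, combine_text_under_n_chars=200)
for chunk in chunks[:3]:
    print(type(chunk).__name__, chunk.text[:80])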