peihsin0715
committed on
Commit
·
7c447a5
1
Parent(s):
842d2fd
Add all project files for HF Spaces deployment
Browse files- Dockerfile +55 -0
- backend/requirements.txt +164 -0
- backend/server.py +600 -0
- backend/utils/finetune.py +120 -0
- backend/utils/sampling.py +120 -0
- backend/utils/utils.py +552 -0
- frontend/.gitignore +24 -0
- frontend/README.md +69 -0
- frontend/eslint.config.js +23 -0
- frontend/index.html +13 -0
- frontend/package-lock.json +0 -0
- frontend/package.json +35 -0
- frontend/postcss.config.js +6 -0
- frontend/public/vite.svg +1 -0
- frontend/src/App.css +42 -0
- frontend/src/App.tsx +78 -0
- frontend/src/assets/react.svg +1 -0
- frontend/src/components/MetricCard.tsx +29 -0
- frontend/src/components/PipelineProgress.tsx +386 -0
- frontend/src/components/validators/DatasetValidator.tsx +78 -0
- frontend/src/components/validators/ModelValidator.tsx +80 -0
- frontend/src/constants/datasets.ts +10 -0
- frontend/src/constants/models.ts +11 -0
- frontend/src/hooks/JobRunnerProvider.tsx +121 -0
- frontend/src/hooks/useHFValidators.ts +129 -0
- frontend/src/hooks/useIterationData.ts +14 -0
- frontend/src/hooks/useJobRunner.ts +208 -0
- frontend/src/index.css +4 -0
- frontend/src/main.tsx +10 -0
- frontend/src/pages/ConfigPage.tsx +649 -0
- frontend/src/pages/ResultsPage.tsx +144 -0
- frontend/src/services/api.ts +70 -0
- frontend/src/services/hf.ts +12 -0
- frontend/src/types/index.ts +60 -0
- frontend/src/vite-env.d.ts +1 -0
- frontend/tailwind.config.js +11 -0
- frontend/tsconfig.app.json +27 -0
- frontend/tsconfig.json +8 -0
- frontend/tsconfig.node.json +25 -0
- frontend/vite.config.ts +7 -0
- nginx.conf.template +40 -0
Dockerfile
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ---------- Frontend build ----------
|
| 2 |
+
FROM node:20-bullseye AS fe
|
| 3 |
+
WORKDIR /app/frontend
|
| 4 |
+
COPY frontend/package*.json ./
|
| 5 |
+
RUN npm ci
|
| 6 |
+
COPY frontend/ ./
|
| 7 |
+
RUN npm run build
|
| 8 |
+
|
| 9 |
+
# ---------- Backend build ----------
|
| 10 |
+
FROM python:3.11-slim AS be
|
| 11 |
+
ENV PIP_NO_CACHE_DIR=1 PYTHONUNBUFFERED=1 PIP_PREFER_BINARY=1
|
| 12 |
+
WORKDIR /app
|
| 13 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 14 |
+
build-essential gcc g++ gfortran make pkg-config \
|
| 15 |
+
libopenblas-dev liblapack-dev git \
|
| 16 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 17 |
+
RUN python -m pip install --upgrade pip setuptools wheel
|
| 18 |
+
RUN pip install --index-url https://download.pytorch.org/whl/cpu "torch==2.3.1"
|
| 19 |
+
COPY backend/requirements.txt ./backend/requirements.txt
|
| 20 |
+
RUN sed -i 's/^[Tt]orch[[:space:]=<>!].*/# torch pinned separately (CPU)/' backend/requirements.txt || true
|
| 21 |
+
RUN pip install --only-binary=:all: blis || echo "Precompiled blis not available"
|
| 22 |
+
RUN pip install -r backend/requirements.txt || pip install -r backend/requirements.txt --no-deps
|
| 23 |
+
COPY backend/ ./backend/
|
| 24 |
+
|
| 25 |
+
# ---------- Runtime ----------
|
| 26 |
+
FROM python:3.11-slim AS runtime
|
| 27 |
+
ENV PYTHONUNBUFFERED=1 PIP_NO_CACHE_DIR=1 PORT=7860
|
| 28 |
+
WORKDIR /app
|
| 29 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 30 |
+
nginx supervisor ca-certificates \
|
| 31 |
+
libgomp1 libopenblas0 \
|
| 32 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 33 |
+
COPY --from=fe /app/frontend/dist /usr/share/nginx/html
|
| 34 |
+
COPY --from=be /usr/local /usr/local
|
| 35 |
+
COPY --from=be /app/backend /app/backend
|
| 36 |
+
RUN python -m pip install --no-cache-dir gunicorn
|
| 37 |
+
|
| 38 |
+
COPY nginx.conf.template /etc/nginx/nginx.conf
|
| 39 |
+
|
| 40 |
+
RUN mkdir -p /etc/supervisor/conf.d && \
|
| 41 |
+
printf "[program:api]\n\
|
| 42 |
+
command=gunicorn --workers 2 --threads 8 --timeout 0 --chdir /app/backend -b 0.0.0.0:5001 server:app\n\
|
| 43 |
+
priority=10\nautostart=true\nautorestart=true\n\
|
| 44 |
+
stdout_logfile=/dev/stdout\nstderr_logfile=/dev/stderr\n\
|
| 45 |
+
stdout_logfile_maxbytes=0\nstderr_logfile_maxbytes=0\n\n\
|
| 46 |
+
[program:nginx]\n\
|
| 47 |
+
command=nginx -g \"daemon off;\"\n\
|
| 48 |
+
priority=20\nautostart=true\nautorestart=true\n\
|
| 49 |
+
stdout_logfile=/dev/stdout\nstderr_logfile=/dev/stderr\n\
|
| 50 |
+
stdout_logfile_maxbytes=0\nstderr_logfile_maxbytes=0\n\n\
|
| 51 |
+
[supervisord]\nlogfile=/dev/stdout\nlogfile_maxbytes=0\nnodaemon=true\nuser=root\n" \
|
| 52 |
+
> /etc/supervisor/conf.d/app.conf
|
| 53 |
+
|
| 54 |
+
EXPOSE 7860
|
| 55 |
+
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/app.conf"]
|
backend/requirements.txt
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==1.10.1
|
| 2 |
+
aiohttp==3.9.1
|
| 3 |
+
aiosignal==1.3.1
|
| 4 |
+
alabaster>=0.7,<0.8
|
| 5 |
+
annotated-types==0.6.0
|
| 6 |
+
anyio==4.2.0
|
| 7 |
+
arrow==1.3.0
|
| 8 |
+
attrs==23.2.0
|
| 9 |
+
babel==2.17.0
|
| 10 |
+
beautifulsoup4==4.13.3
|
| 11 |
+
bibtexparser==1.4.3
|
| 12 |
+
boto3==1.36.14
|
| 13 |
+
bs4==0.0.2
|
| 14 |
+
catalogue==2.0.10
|
| 15 |
+
certifi==2023.11.17
|
| 16 |
+
charset-normalizer==3.3.2
|
| 17 |
+
click==8.1.7
|
| 18 |
+
cloudpathlib==0.20.0
|
| 19 |
+
colorama==0.4.6
|
| 20 |
+
confection==0.1.5
|
| 21 |
+
contourpy==1.3.1
|
| 22 |
+
cycler==0.12.1
|
| 23 |
+
cymem==2.0.11
|
| 24 |
+
dataclasses-json==0.6.3
|
| 25 |
+
datasets==2.18.0
|
| 26 |
+
defusedxml==0.7.1
|
| 27 |
+
Deprecated==1.2.18
|
| 28 |
+
dill==0.3.8
|
| 29 |
+
distro==1.9.0
|
| 30 |
+
fake-useragent==2.0.3
|
| 31 |
+
fastapi==0.109.0
|
| 32 |
+
filelock==3.17.0
|
| 33 |
+
Flask==2.0.3
|
| 34 |
+
Flask-Cors==3.0.10
|
| 35 |
+
fonttools==4.56.0
|
| 36 |
+
fpdf2==2.8.3
|
| 37 |
+
free-proxy==1.1.3
|
| 38 |
+
frozenlist==1.4.1
|
| 39 |
+
fsspec==2024.2.0
|
| 40 |
+
gensim==4.3.3
|
| 41 |
+
h11==0.14.0
|
| 42 |
+
hdbscan==0.8.40
|
| 43 |
+
hf-xet==1.1.9
|
| 44 |
+
httpcore==1.0.2
|
| 45 |
+
httpx==0.26.0
|
| 46 |
+
huggingface-hub==0.34.4
|
| 47 |
+
idna==3.6
|
| 48 |
+
imagesize==1.4.1
|
| 49 |
+
itsdangerous==2.2.0
|
| 50 |
+
Jinja2==3.0.3
|
| 51 |
+
jmespath==1.0.1
|
| 52 |
+
joblib==1.4.2
|
| 53 |
+
jsonpatch==1.33
|
| 54 |
+
jsonpointer==2.4
|
| 55 |
+
kiwisolver==1.4.8
|
| 56 |
+
langchain==0.1.1
|
| 57 |
+
langchain-community==0.0.13
|
| 58 |
+
langchain-core==0.1.13
|
| 59 |
+
langchain-openai==0.0.3
|
| 60 |
+
langcodes==3.5.0
|
| 61 |
+
langserve==0.0.39
|
| 62 |
+
langsmith==0.0.83
|
| 63 |
+
language_data==1.3.0
|
| 64 |
+
lxml==5.3.1
|
| 65 |
+
marisa-trie==1.2.1
|
| 66 |
+
markdown-it-py==3.0.0
|
| 67 |
+
MarkupSafe==3.0.2
|
| 68 |
+
marshmallow==3.20.2
|
| 69 |
+
matplotlib==3.10.0
|
| 70 |
+
mdurl==0.1.2
|
| 71 |
+
mpmath==1.3.0
|
| 72 |
+
multidict==6.0.4
|
| 73 |
+
multiprocess==0.70.16
|
| 74 |
+
murmurhash==1.0.12
|
| 75 |
+
mypy-extensions==1.0.0
|
| 76 |
+
narwhals==1.39.0
|
| 77 |
+
networkx==3.4.2
|
| 78 |
+
nltk==3.8.1
|
| 79 |
+
numpy==1.26.4
|
| 80 |
+
openai==1.9.0
|
| 81 |
+
orjson==3.9.12
|
| 82 |
+
outcome==1.3.0.post0
|
| 83 |
+
packaging==23.2
|
| 84 |
+
pandas==2.1.4
|
| 85 |
+
Pillow==9.5.0
|
| 86 |
+
plotly==6.0.1
|
| 87 |
+
preshed==3.0.9
|
| 88 |
+
psutil==7.0.0
|
| 89 |
+
pyarrow==14.0.2
|
| 90 |
+
pyarrow-hotfix==0.7
|
| 91 |
+
pyasn1==0.6.1
|
| 92 |
+
pydantic==2.5.3
|
| 93 |
+
pydantic_core==2.14.6
|
| 94 |
+
Pygments==2.19.1
|
| 95 |
+
pyparsing==3.2.1
|
| 96 |
+
PySocks==1.7.1
|
| 97 |
+
python-dateutil==2.9.0.post0
|
| 98 |
+
python-dotenv==1.0.0
|
| 99 |
+
pytz==2025.1
|
| 100 |
+
PyYAML==6.0.1
|
| 101 |
+
rank-bm25==0.2.2
|
| 102 |
+
regex==2023.12.25
|
| 103 |
+
requests==2.32.3
|
| 104 |
+
rich==13.9.4
|
| 105 |
+
roman-numerals-py==3.1.0
|
| 106 |
+
rsa==4.7.2
|
| 107 |
+
s3transfer==0.11.2
|
| 108 |
+
safetensors==0.5.2
|
| 109 |
+
SAGEDbias @ https://github.com/holistic-ai/SAGED-Bias/archive/8d5664387c58d94ffd10667c40493a5e460eaac6.zip
|
| 110 |
+
scholarly==1.7.11
|
| 111 |
+
scikit-learn==1.6.1
|
| 112 |
+
scipy==1.13.1
|
| 113 |
+
selenium==4.29.0
|
| 114 |
+
sentence-transformers==3.4.1
|
| 115 |
+
shellingham==1.5.4
|
| 116 |
+
six==1.17.0
|
| 117 |
+
smart-open==7.1.0
|
| 118 |
+
sniffio==1.3.0
|
| 119 |
+
snowballstemmer==2.2.0
|
| 120 |
+
sortedcontainers==2.4.0
|
| 121 |
+
soupsieve==2.6
|
| 122 |
+
spacy==3.8.7
|
| 123 |
+
spacy-legacy==3.0.12
|
| 124 |
+
spacy-loggers==1.0.5
|
| 125 |
+
Sphinx==7.2.6
|
| 126 |
+
sphinx-rtd-theme==3.0.2
|
| 127 |
+
sphinxcontrib-applehelp==2.0.0
|
| 128 |
+
sphinxcontrib-devhelp==2.0.0
|
| 129 |
+
sphinxcontrib-htmlhelp==2.1.0
|
| 130 |
+
sphinxcontrib-jquery==4.1
|
| 131 |
+
sphinxcontrib-jsmath==1.0.1
|
| 132 |
+
sphinxcontrib-qthelp==2.0.0
|
| 133 |
+
sphinxcontrib-serializinghtml==2.0.0
|
| 134 |
+
SQLAlchemy==2.0.25
|
| 135 |
+
srsly==2.5.1
|
| 136 |
+
starlette==0.35.1
|
| 137 |
+
sympy==1.13.1
|
| 138 |
+
tenacity==8.2.3
|
| 139 |
+
thinc==8.3.4
|
| 140 |
+
threadpoolctl==3.5.0
|
| 141 |
+
tiktoken==0.5.2
|
| 142 |
+
tokenizers==0.22.0
|
| 143 |
+
tqdm==4.67.1
|
| 144 |
+
transformers==4.56.1
|
| 145 |
+
trio==0.29.0
|
| 146 |
+
trio-websocket==0.12.2
|
| 147 |
+
typer==0.15.1
|
| 148 |
+
types-python-dateutil==2.9.0.20241206
|
| 149 |
+
typing-inspect==0.9.0
|
| 150 |
+
typing_extensions==4.12.2
|
| 151 |
+
tzdata==2025.1
|
| 152 |
+
urllib3==2.1.0
|
| 153 |
+
uvicorn==0.26.0
|
| 154 |
+
wasabi==1.1.3
|
| 155 |
+
weasel==0.4.1
|
| 156 |
+
websocket-client==1.8.0
|
| 157 |
+
Werkzeug==2.0.3
|
| 158 |
+
Wikipedia-API==0.7.3
|
| 159 |
+
wrapt==1.17.2
|
| 160 |
+
wsproto==1.2.0
|
| 161 |
+
xgboost==3.0.0
|
| 162 |
+
xxhash==3.5.0
|
| 163 |
+
yarl==1.9.4
|
| 164 |
+
seaborn
|
backend/server.py
ADDED
|
@@ -0,0 +1,600 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from flask import Flask, request, jsonify, send_file, send_from_directory
|
| 2 |
+
from flask_cors import CORS
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import logging
|
| 9 |
+
from functools import lru_cache
|
| 10 |
+
from typing import Optional, List, Dict, Any
|
| 11 |
+
from utils.utils import _ensure_plot_saved
|
| 12 |
+
|
| 13 |
+
os.environ["MPLBACKEND"] = "Agg"
|
| 14 |
+
os.environ["QT_QPA_PLATFORM"] = "offscreen"
|
| 15 |
+
|
| 16 |
+
logging.basicConfig(level=logging.INFO)
|
| 17 |
+
|
| 18 |
+
from utils.sampling import rank_sample
|
| 19 |
+
try:
|
| 20 |
+
from transformers import TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments
|
| 21 |
+
print("✓ transformers training components imported")
|
| 22 |
+
except Exception as e:
|
| 23 |
+
print(f"✗ transformers training import failed: {e}")
|
| 24 |
+
def finetune(*args, **kwargs):
|
| 25 |
+
print("Warning: Transformers training components not available, skipping fine-tuning")
|
| 26 |
+
return None
|
| 27 |
+
|
| 28 |
+
# 🤗 datasets
|
| 29 |
+
try:
|
| 30 |
+
from datasets import (
|
| 31 |
+
load_dataset,
|
| 32 |
+
load_dataset_builder,
|
| 33 |
+
get_dataset_config_names,
|
| 34 |
+
get_dataset_split_names,
|
| 35 |
+
Features,
|
| 36 |
+
)
|
| 37 |
+
print("✓ datasets imported")
|
| 38 |
+
except Exception as e:
|
| 39 |
+
print(f"✗ datasets import failed: {e}")
|
| 40 |
+
raise
|
| 41 |
+
|
| 42 |
+
from utils.utils import (
|
| 43 |
+
generate_topk_samples,
|
| 44 |
+
evaluate_generated_outputs,
|
| 45 |
+
load_model_and_tokenizer,
|
| 46 |
+
generate_counterfactual_augmentations,
|
| 47 |
+
)
|
| 48 |
+
print("✓ utils imported")
|
| 49 |
+
|
| 50 |
+
app = Flask(__name__)
|
| 51 |
+
CORS(app)
|
| 52 |
+
|
| 53 |
+
_MODELS = {}
|
| 54 |
+
_CURRENT_DATASET = None
|
| 55 |
+
_GENERATION_RESULTS = None
|
| 56 |
+
|
| 57 |
+
@app.route('/data/<path:filename>')
def serve_data(filename):
    """Serve a generated artifact (plot image or CSV) from the local ./data directory.

    Returns the raw file bytes with a mimetype inferred from the extension,
    404 when the file is missing or the path escapes the data directory,
    and 500 on read errors.
    """
    import os
    from flask import Response

    print(f"[Static] Requested file: {filename}")

    data_dir = os.path.abspath('data')
    # Resolve the requested path and reject anything that escapes data_dir:
    # the <path:...> converter accepts '/' and '..' components, so joining
    # blindly would allow arbitrary file reads (path traversal).
    file_path = os.path.abspath(os.path.join(data_dir, filename))
    if os.path.commonpath([data_dir, file_path]) != data_dir:
        return "File not found", 404

    print(f"[Static] Full path: {file_path}")
    print(f"[Static] File exists: {os.path.exists(file_path)}")

    if not os.path.exists(file_path):
        return "File not found", 404

    try:
        with open(file_path, 'rb') as f:
            file_data = f.read()

        # Minimal extension -> mimetype mapping for the artifacts this app emits.
        if filename.endswith('.png'):
            mimetype = 'image/png'
        elif filename.endswith('.jpg') or filename.endswith('.jpeg'):
            mimetype = 'image/jpeg'
        elif filename.endswith('.csv'):
            mimetype = 'text/csv'
        else:
            mimetype = 'application/octet-stream'

        print(f"[Static] Serving {len(file_data)} bytes as {mimetype}")

        return Response(file_data, mimetype=mimetype)

    except Exception as e:
        print(f"[Static] Error reading file: {e}")
        return f"Error reading file: {str(e)}", 500
|
| 93 |
+
|
| 94 |
+
@app.route('/debug/files', methods=['GET'])
def debug_files():
    """Debug endpoint: list every entry in the local ./data directory with its size."""
    try:
        data_dir = os.path.abspath('data')
        if not os.path.exists(data_dir):
            return jsonify({"error": "Data directory not found", "path": data_dir})

        listing = []
        for entry in os.listdir(data_dir):
            entry_path = os.path.join(data_dir, entry)
            present = os.path.exists(entry_path)
            listing.append({
                "name": entry,
                "path": entry_path,
                "exists": present,
                "size": os.path.getsize(entry_path) if present else 0,
            })

        return jsonify({
            "data_directory": data_dir,
            "files": listing,
        })
    except Exception as e:
        return jsonify({"error": str(e)})
|
| 117 |
+
|
| 118 |
+
def get_model(model_name: str):
    """Return (tokenizer, model, device) for model_name, loading and caching on first use."""
    cached = _MODELS.get(model_name)
    if cached is not None:
        print(f"Using cached model: {model_name}")
        return cached
    print(f"Loading new model: {model_name}")
    tokenizer, model, device = load_model_and_tokenizer(model_name)
    # Cache so repeated pipeline runs reuse the same in-memory model.
    _MODELS[model_name] = (tokenizer, model, device)
    return tokenizer, model, device
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report server time, cached models, and pipeline state flags."""
    payload = {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "loaded_models": list(_MODELS.keys()),
        "dataset_loaded": _CURRENT_DATASET is not None,
        "generation_results_available": _GENERATION_RESULTS is not None,
    }
    return jsonify(payload)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _flatten_features(feats, prefix: str = "") -> List[str]:
|
| 140 |
+
cols: List[str] = []
|
| 141 |
+
try:
|
| 142 |
+
items = feats.items() if isinstance(feats, (Features, dict)) else feats.items()
|
| 143 |
+
except Exception:
|
| 144 |
+
try:
|
| 145 |
+
return list(feats.keys())
|
| 146 |
+
except Exception:
|
| 147 |
+
return cols
|
| 148 |
+
for name, sub in items:
|
| 149 |
+
full = f"{prefix}.{name}" if prefix else name
|
| 150 |
+
try:
|
| 151 |
+
if isinstance(sub, (Features, dict)):
|
| 152 |
+
cols += _flatten_features(sub, prefix=full)
|
| 153 |
+
else:
|
| 154 |
+
cols.append(full)
|
| 155 |
+
except Exception:
|
| 156 |
+
cols.append(full)
|
| 157 |
+
return cols
|
| 158 |
+
|
| 159 |
+
@lru_cache(maxsize=256)
def _get_dataset_fields_cached(dataset_id: str, config: Optional[str], split: str) -> List[str]:
    """Resolve the flattened field names of a HF dataset, cached per (id, config, split).

    Tries builder metadata first (no data download); falls back to streaming a
    single row. Raises RuntimeError carrying both failure messages when neither
    approach works.
    """
    try:
        info_features = load_dataset_builder(dataset_id, name=config).info.features
        return sorted(set(_flatten_features(info_features)))
    except Exception as e_builder:
        try:
            stream = load_dataset(dataset_id, name=config, split=split, streaming=True)
            first_row = next(iter(stream.take(1)), None)
            if first_row is None:
                return []
            return sorted(set(first_row.keys()))
        except Exception as e_stream:
            raise RuntimeError(f"builder_error={e_builder}; streaming_error={e_stream}")
|
| 176 |
+
|
| 177 |
+
@app.route('/dataset/fields', methods=['GET'])
def dataset_fields():
    """GET /dataset/fields?id=...&config=...&split=... — flattened feature names of a HF dataset."""
    dataset_id = request.args.get('id')
    cfg = request.args.get('config')
    split = request.args.get('split', 'train')
    if not dataset_id:
        return jsonify({"error": "Missing required query param 'id'"}), 400
    try:
        fields = _get_dataset_fields_cached(dataset_id, cfg, split)
    except Exception as e:
        return jsonify({
            "error": "Failed to fetch dataset fields",
            "datasetId": dataset_id,
            "config": cfg,
            "split": split,
            "detail": str(e),
        }), 400
    return jsonify({
        "fields": fields,
        "datasetId": dataset_id,
        "config": cfg,
        "split": split,
        "source": "huggingface-builder" if fields else "unknown",
    })
|
| 201 |
+
|
| 202 |
+
@app.route('/dataset/meta', methods=['GET'])
def dataset_meta():
    """GET /dataset/meta?id=... — available configs and splits for a HF dataset.

    Both lookups are best-effort: failures are logged and reported as empty lists.
    """
    dataset_id = request.args.get('id')
    if not dataset_id:
        return jsonify({"error": "Missing required query param 'id'"}), 400

    try:
        configs = get_dataset_config_names(dataset_id)
    except Exception as e:
        logging.warning(f"get_dataset_config_names failed for {dataset_id}: {e}")
        configs = []

    def _splits_for(name: Optional[str]) -> List[str]:
        # Builder metadata avoids downloading data; fall back to the split-name API.
        try:
            if name is None:
                builder = load_dataset_builder(dataset_id)
            else:
                builder = load_dataset_builder(dataset_id, name=name)
            return sorted(list(builder.info.splits) or [])
        except Exception:
            if name is None:
                return get_dataset_split_names(dataset_id)
            return get_dataset_split_names(dataset_id, name)

    try:
        # When configs exist, splits are resolved against the first config.
        splits = _splits_for(configs[0] if configs else None)
    except Exception as e:
        logging.warning(f"get splits failed for {dataset_id}: {e}")
        splits = []

    return jsonify({
        "datasetId": dataset_id,
        "configs": configs,
        "splits": splits,
    })
|
| 234 |
+
|
| 235 |
+
@app.route('/dataset/field-stats', methods=['GET'])
def dataset_field_stats():
    """GET /dataset/field-stats — value counts of `field`, optionally grouped by `subfield`.

    Streams at most 50k rows of the dataset. Returns a flat {value: count}
    mapping, or {field_value: {subfield_value: count}} when `subfield` is given;
    rows where either requested value is None are skipped.
    """
    dataset_id = request.args.get('id')
    cfg = request.args.get('config')
    split = request.args.get('split', 'train')
    field = request.args.get('field')
    subfield = request.args.get('subfield')
    if not dataset_id or not field:
        return jsonify({"error": "Missing required query params 'id' or 'field'"}), 400
    try:
        stream = load_dataset(dataset_id, name=cfg, split=split, streaming=True)
        max_rows = 50000  # hard cap so streaming stats stay bounded
        counter: Dict[str, Any] = {}
        print(f"[field-stats] Computing stats for '{field}'" + (f" → '{subfield}'" if subfield else ""))
        for idx, row in enumerate(stream):
            if idx >= max_rows:
                break
            main_val = row.get(field)
            if main_val is None:
                continue
            if subfield:
                sub_val = row.get(subfield)
                if sub_val is None:
                    continue
                bucket = counter.setdefault(main_val, {})
                bucket[sub_val] = bucket.get(sub_val, 0) + 1
            else:
                counter[main_val] = counter.get(main_val, 0) + 1
        return jsonify({
            "field": field,
            "subfield": subfield,
            "datasetId": dataset_id,
            "config": cfg,
            "split": split,
            "counts": counter,
        })
    except Exception as e:
        return jsonify({
            "error": f"Failed to compute field stats: {str(e)}",
            "datasetId": dataset_id,
            "config": cfg,
            "split": split,
            "field": field,
            "subfield": subfield,
        }), 500
|
| 280 |
+
|
| 281 |
+
def _parse_selected_groups_from_config(config: dict) -> List[str]:
|
| 282 |
+
raw = config.get('selectedCfFields', []) or []
|
| 283 |
+
out: List[str] = []
|
| 284 |
+
for s in raw:
|
| 285 |
+
s = (s or "").strip()
|
| 286 |
+
if not s:
|
| 287 |
+
continue
|
| 288 |
+
if "/" in s:
|
| 289 |
+
out.append(s.split("/")[-1])
|
| 290 |
+
else:
|
| 291 |
+
out.append(s)
|
| 292 |
+
seen = set()
|
| 293 |
+
uniq = []
|
| 294 |
+
for x in out:
|
| 295 |
+
if x not in seen:
|
| 296 |
+
uniq.append(x)
|
| 297 |
+
seen.add(x)
|
| 298 |
+
return uniq
|
| 299 |
+
|
| 300 |
+
def stratified_sample_by_category(df: pd.DataFrame, category_col: str, groups: List[str], total_n: Optional[int]) -> pd.DataFrame:
    """Sample roughly `total_n` rows from `df`, spread as evenly as possible across `groups`.

    Returns `df` unchanged when `total_n` is None or <= 0. When none of the
    requested groups appear in `category_col`, falls back to a plain random
    sample. Sampling is deterministic (fixed random states).

    Bug fixed vs. the original: the top-up step previously used the index of a
    DataFrame built with ignore_index=True to exclude already-selected rows, so
    it dropped the wrong rows from the pool and could return duplicates. We now
    track the selected rows' original indices throughout.
    NOTE(review): assumes `df` has a unique index (true for pd.DataFrame(ds)
    loads in this file) — confirm for other callers.
    """
    if total_n is None or total_n <= 0:
        return df

    groups_present = [g for g in groups if g in df[category_col].unique()]
    if not groups_present:
        # None of the requested groups exist: plain random sample.
        return df.sample(n=min(total_n, len(df)), random_state=42)

    # Base allocation: up to total_n // n_groups rows per group (at least 1).
    base_each = max(1, total_n // max(1, len(groups_present)))
    parts = []
    for g in groups_present:
        gdf = df[df[category_col] == g]
        need = min(base_each, len(gdf))
        if need > 0:
            parts.append(gdf.sample(n=need, random_state=42))

    out = pd.concat(parts) if parts else df.iloc[0:0]

    # Round-robin top-up across groups, excluding already-selected rows so the
    # result never contains duplicates.
    remainder = total_n - len(out)
    i = 0
    misses = 0
    while remainder > 0 and misses < len(groups_present):
        g = groups_present[i % len(groups_present)]
        gpool = df[(df[category_col] == g) & (~df.index.isin(out.index))]
        if len(gpool) > 0:
            out = pd.concat([out, gpool.sample(n=1, random_state=42 + remainder)])
            remainder -= 1
            misses = 0
        else:
            misses += 1
        i += 1

    # Final top-up from any remaining rows regardless of group.
    if remainder > 0:
        pool = df[~df.index.isin(out.index)]
        take = min(remainder, len(pool))
        if take > 0:
            out = pd.concat([out, pool.sample(n=take, random_state=777)])

    return out.reset_index(drop=True)
|
| 334 |
+
|
| 335 |
+
def _pairwise_max_abs_diff(means: Dict[str, float]) -> float:
|
| 336 |
+
from itertools import combinations
|
| 337 |
+
keys = list(means.keys())
|
| 338 |
+
if len(keys) < 2:
|
| 339 |
+
return 0.0
|
| 340 |
+
diffs = [abs(means[a] - means[b]) for a, b in combinations(keys, 2)]
|
| 341 |
+
return float(max(diffs)) if diffs else 0.0
|
| 342 |
+
|
| 343 |
+
def _mean_by_cat(df: pd.DataFrame, cats: List[str], score_col: str = "sentiment_score") -> Dict[str, float]:
|
| 344 |
+
out: Dict[str, float] = {}
|
| 345 |
+
for c in cats:
|
| 346 |
+
sub = df[df["category"] == c]
|
| 347 |
+
if len(sub) > 0:
|
| 348 |
+
out[c] = float(sub[score_col].mean())
|
| 349 |
+
return out
|
| 350 |
+
|
| 351 |
+
@app.route('/pipeline', methods=['POST'])
|
| 352 |
+
def run_pipeline():
|
| 353 |
+
"""Run the complete pipeline with frontend JobConfig format"""
|
| 354 |
+
data = request.get_json() or {}
|
| 355 |
+
config = data.get('config', data) or {}
|
| 356 |
+
print("[DEBUG] Received config:", config)
|
| 357 |
+
|
| 358 |
+
dataset_id = config.get('dataset') or "AmazonScience/bold"
|
| 359 |
+
model_name = config.get('languageModel', 'openai-community/gpt2')
|
| 360 |
+
top_k = int(config.get('k', 5))
|
| 361 |
+
dataset_limit_raw = config.get('datasetLimit')
|
| 362 |
+
dataset_limit = int(dataset_limit_raw) if dataset_limit_raw is not None else None
|
| 363 |
+
num_cf_per_row = int(config.get('numCounterfactuals') or 3)
|
| 364 |
+
tau = float(config.get('tau', 0.1))
|
| 365 |
+
iterations = int(config.get('iterations', 1000))
|
| 366 |
+
metric_target = config.get('metrictarget')
|
| 367 |
+
|
| 368 |
+
try:
|
| 369 |
+
results = {}
|
| 370 |
+
global _CURRENT_DATASET, _GENERATION_RESULTS
|
| 371 |
+
|
| 372 |
+
print("Pipeline Step 1: Loading data...")
|
| 373 |
+
ds = load_dataset(dataset_id, split="train")
|
| 374 |
+
df_full = pd.DataFrame(ds)[["domain", "name", "category", "prompts", "wikipedia"]].copy()
|
| 375 |
+
|
| 376 |
+
selected_groups = _parse_selected_groups_from_config(config)
|
| 377 |
+
present_all = sorted(df_full["category"].dropna().unique().tolist())
|
| 378 |
+
|
| 379 |
+
if selected_groups:
|
| 380 |
+
selected_groups = [g for g in selected_groups if g in present_all]
|
| 381 |
+
if len(selected_groups) < 2:
|
| 382 |
+
print(f"[Filter] Requested groups not enough in dataset (have {selected_groups}); fallback to ALL categories")
|
| 383 |
+
selected_groups = []
|
| 384 |
+
else:
|
| 385 |
+
print("[Filter] No groups requested from frontend; will use categories present after generation.")
|
| 386 |
+
|
| 387 |
+
df_pool = df_full[df_full["category"].isin(selected_groups)].copy() if selected_groups else df_full.copy()
|
| 388 |
+
|
| 389 |
+
df = stratified_sample_by_category(
|
| 390 |
+
df=df_pool,
|
| 391 |
+
category_col="category",
|
| 392 |
+
groups=selected_groups if selected_groups else sorted(df_pool["category"].unique().tolist()),
|
| 393 |
+
total_n=dataset_limit
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
print(f"[Pool] pool_size={len(df_pool)}, sampled={len(df)}")
|
| 397 |
+
print(f"[Pool] categories in pool: {sorted(df_pool['category'].unique().tolist())}")
|
| 398 |
+
print(f"[Pool] categories in sample: {sorted(df['category'].unique().tolist())}")
|
| 399 |
+
|
| 400 |
+
_CURRENT_DATASET = df
|
| 401 |
+
results['data_loaded'] = len(df)
|
| 402 |
+
print(f"Dataset loaded: {len(df)} rows")
|
| 403 |
+
|
| 404 |
+
print("Pipeline Step 2: Loading model...")
|
| 405 |
+
tokenizer, model, device = get_model(model_name)
|
| 406 |
+
results['model_loaded'] = model_name
|
| 407 |
+
|
| 408 |
+
print(f"Pipeline Step 3: Generating samples for {len(df)} entries...")
|
| 409 |
+
generation_results = generate_topk_samples(model, _CURRENT_DATASET, tokenizer, device, top_k=top_k)
|
| 410 |
+
task = config.get('classificationTask', 'sentiment')
|
| 411 |
+
tox_choice = config.get('toxicityModelChoice', 'detoxify')
|
| 412 |
+
|
| 413 |
+
evaluated_results = evaluate_generated_outputs(
|
| 414 |
+
generation_results, device,
|
| 415 |
+
task=task,
|
| 416 |
+
toxicity_model_choice=tox_choice
|
| 417 |
+
)
|
| 418 |
+
_GENERATION_RESULTS = evaluated_results
|
| 419 |
+
|
| 420 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 421 |
+
os.makedirs("data", exist_ok=True)
|
| 422 |
+
output_file = f"data/pipeline_generation_{timestamp}.csv"
|
| 423 |
+
evaluated_results.to_csv(output_file, index=False)
|
| 424 |
+
results['generation_file'] = output_file
|
| 425 |
+
results['generation_samples'] = len(evaluated_results)
|
| 426 |
+
print("Pipeline Step 3.5: Counterfactual augmentation...")
|
| 427 |
+
augmented_results = generate_counterfactual_augmentations(
|
| 428 |
+
evaluated_results,
|
| 429 |
+
text_col="generated",
|
| 430 |
+
name_col="name",
|
| 431 |
+
category_col="category",
|
| 432 |
+
num_cf_per_row=num_cf_per_row
|
| 433 |
+
)
|
| 434 |
+
augmented_file = f"data/pipeline_generation_cf_augmented_{timestamp}.csv"
|
| 435 |
+
augmented_results.to_csv(augmented_file, index=False)
|
| 436 |
+
results['counterfactual_file'] = augmented_file
|
| 437 |
+
results['counterfactual_added'] = len(augmented_results) - len(evaluated_results)
|
| 438 |
+
results['counterfactual_total'] = len(augmented_results)
|
| 439 |
+
|
| 440 |
+
present_after_gen = sorted(evaluated_results["category"].dropna().unique().tolist())
|
| 441 |
+
if not selected_groups:
|
| 442 |
+
selected_groups_used = present_after_gen
|
| 443 |
+
else:
|
| 444 |
+
selected_groups_used = [g for g in selected_groups if g in present_after_gen]
|
| 445 |
+
if len(selected_groups_used) < 2:
|
| 446 |
+
print(f"[Sampling] After generation only {selected_groups_used} present; expanding to all present categories")
|
| 447 |
+
selected_groups_used = present_after_gen
|
| 448 |
+
|
| 449 |
+
print(f"[Sampling] Using groups: {selected_groups_used}")
|
| 450 |
+
|
| 451 |
+
print("Debug: Checking data before sampling...")
|
| 452 |
+
print(f"Total evaluated results: {len(evaluated_results)}")
|
| 453 |
+
print(f"Categories in data: {present_after_gen}")
|
| 454 |
+
print(f"Names in data: {evaluated_results['name'].unique()}")
|
| 455 |
+
|
| 456 |
+
for cat in selected_groups_used:
|
| 457 |
+
cat_count = int((evaluated_results["category"] == cat).sum())
|
| 458 |
+
print(f"Category '{cat}': {cat_count} samples")
|
| 459 |
+
|
| 460 |
+
print(f"Pipeline Step 4: Rank sampling on original evaluated results...(iterations={iterations}, temp={tau})")
|
| 461 |
+
try:
|
| 462 |
+
best_sent_subset = rank_sample(evaluated_results, num_samples=iterations, temp=tau, target_value=metric_target)
|
| 463 |
+
except (ValueError, IndexError) as e:
|
| 464 |
+
print(f"Sampling failed: {e}")
|
| 465 |
+
mid_point = len(evaluated_results) // 2
|
| 466 |
+
best_sent_subset = evaluated_results.iloc[:mid_point].copy()
|
| 467 |
+
|
| 468 |
+
sent_file = f"data/pipeline_sent_subset_{timestamp}.csv"
|
| 469 |
+
best_sent_subset.to_csv(sent_file, index=False)
|
| 470 |
+
|
| 471 |
+
print(f"Pipeline Step 5: Rank sampling on CF-augmented results...(iterations={iterations}, temp={tau})")
|
| 472 |
+
try:
|
| 473 |
+
cf_best_sent_subset = rank_sample(augmented_results, num_samples=iterations, temp=tau, target_value=metric_target)
|
| 474 |
+
except (ValueError, IndexError) as e:
|
| 475 |
+
print(f"CF Sampling failed: {e}")
|
| 476 |
+
mid_point = len(augmented_results) // 2
|
| 477 |
+
cf_best_sent_subset = augmented_results.iloc[:mid_point].copy()
|
| 478 |
+
|
| 479 |
+
cf_sent_file = f"data/pipeline_cf_sent_subset_{timestamp}.csv"
|
| 480 |
+
cf_best_sent_subset.to_csv(cf_sent_file, index=False)
|
| 481 |
+
|
| 482 |
+
orig_means = _mean_by_cat(best_sent_subset, selected_groups_used)
|
| 483 |
+
final_mean_diff = _pairwise_max_abs_diff(orig_means)
|
| 484 |
+
|
| 485 |
+
cf_means = _mean_by_cat(cf_best_sent_subset, selected_groups_used)
|
| 486 |
+
cf_final_mean_diff = _pairwise_max_abs_diff(cf_means)
|
| 487 |
+
|
| 488 |
+
print("Pipeline Step 6: Plotting distributions...")
|
| 489 |
+
|
| 490 |
+
def _safe(s: str) -> str:
|
| 491 |
+
import re
|
| 492 |
+
return re.sub(r"[^A-Za-z0-9_.-]+", "_", s)
|
| 493 |
+
|
| 494 |
+
orig_sent_title = _safe(f"{timestamp}_original_distribution")
|
| 495 |
+
cf_sent_title = _safe(f"{timestamp}_cf_distribution")
|
| 496 |
+
|
| 497 |
+
score_col = None
|
| 498 |
+
for c in [
|
| 499 |
+
"sentiment_score", "regard_score", "toxicity_score",
|
| 500 |
+
"stereotype_gender_score", "stereotype_religion_score",
|
| 501 |
+
"stereotype_profession_score", "stereotype_race_score",
|
| 502 |
+
"personality_score",
|
| 503 |
+
]:
|
| 504 |
+
if c in best_sent_subset.columns:
|
| 505 |
+
score_col = c
|
| 506 |
+
break
|
| 507 |
+
if score_col is None:
|
| 508 |
+
raise KeyError(f"No score column found. Available: {list(best_sent_subset.columns)}")
|
| 509 |
+
|
| 510 |
+
orig_path = _ensure_plot_saved(
|
| 511 |
+
best_sent_subset, score_col, orig_sent_title,
|
| 512 |
+
group_col="category", target=metric_target
|
| 513 |
+
)
|
| 514 |
+
cf_path = _ensure_plot_saved(
|
| 515 |
+
cf_best_sent_subset, score_col, cf_sent_title,
|
| 516 |
+
group_col="category", target=metric_target
|
| 517 |
+
)
|
| 518 |
+
print("[Plot check exists]", orig_path, os.path.exists(orig_path))
|
| 519 |
+
print("[Plot check exists]", cf_path, os.path.exists(cf_path))
|
| 520 |
+
|
| 521 |
+
results['plots'] = {
|
| 522 |
+
'original_sentiment': f"/data/{orig_sent_title}.png",
|
| 523 |
+
'counterfactual_sentiment': f"/data/{cf_sent_title}.png",
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
print("[Plot urls]", results['plots'])
|
| 527 |
+
|
| 528 |
+
if config.get("enableFineTuning"):
|
| 529 |
+
print("Pipeline Step 7: Fine-tuning enabled, starting training...")
|
| 530 |
+
|
| 531 |
+
ft_cfg = config.get("finetuneParams", {}) or {}
|
| 532 |
+
epochs = int(ft_cfg.get("epochs", 3))
|
| 533 |
+
batch_size = int(ft_cfg.get("batchSize", 8))
|
| 534 |
+
lr = float(ft_cfg.get("learningRate", 5e-5))
|
| 535 |
+
|
| 536 |
+
input_csv = augmented_file
|
| 537 |
+
ft_output_dir = f"data/ft_{timestamp}"
|
| 538 |
+
os.makedirs(ft_output_dir, exist_ok=True)
|
| 539 |
+
|
| 540 |
+
try:
|
| 541 |
+
from utils.finetune import finetune_gpt2_from_csv
|
| 542 |
+
finetune_gpt2_from_csv(
|
| 543 |
+
csv_path=input_csv,
|
| 544 |
+
output_dir=ft_output_dir,
|
| 545 |
+
epochs=epochs,
|
| 546 |
+
batch_size=batch_size,
|
| 547 |
+
lr=lr
|
| 548 |
+
)
|
| 549 |
+
print(f"[Fine-tune] Saved fine-tuned model to {ft_output_dir}")
|
| 550 |
+
results["finetuned_model_dir"] = ft_output_dir
|
| 551 |
+
zip_base = f"data/ft_{timestamp}"
|
| 552 |
+
import shutil
|
| 553 |
+
zip_path = shutil.make_archive(zip_base, 'zip', ft_output_dir)
|
| 554 |
+
results["finetuned_model_zip"] = f"/data/{os.path.basename(zip_path)}"
|
| 555 |
+
except Exception as fe:
|
| 556 |
+
print(f"[Fine-tune] Failed: {fe}")
|
| 557 |
+
results["finetuned_model_error"] = str(fe)
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
results.update({
|
| 561 |
+
'sampling_method': 'rank_sentiment_only',
|
| 562 |
+
'used_groups': selected_groups_used,
|
| 563 |
+
'sentiment_subset_file': sent_file,
|
| 564 |
+
'cf_sentiment_subset_file': cf_sent_file,
|
| 565 |
+
'sentiment_subset_size': len(best_sent_subset),
|
| 566 |
+
'cf_sentiment_subset_size': len(cf_best_sent_subset),
|
| 567 |
+
'config_used': config,
|
| 568 |
+
'metrics': {
|
| 569 |
+
'finalMeanDiff': final_mean_diff,
|
| 570 |
+
'cfFinalMeanDiff': cf_final_mean_diff,
|
| 571 |
+
'reductionPct': (0.0 if final_mean_diff == 0 else max(0.0, (final_mean_diff - cf_final_mean_diff) / abs(final_mean_diff) * 100.0)),
|
| 572 |
+
'stableCoverage': 100.0
|
| 573 |
+
}
|
| 574 |
+
})
|
| 575 |
+
|
| 576 |
+
return jsonify({
|
| 577 |
+
"status": "success",
|
| 578 |
+
"message": "Complete pipeline executed successfully (with counterfactual augmentation)",
|
| 579 |
+
"results": results,
|
| 580 |
+
"timestamp": timestamp
|
| 581 |
+
})
|
| 582 |
+
|
| 583 |
+
except Exception as e:
|
| 584 |
+
print(f"Error in pipeline: {str(e)}")
|
| 585 |
+
return jsonify({
|
| 586 |
+
"status": "error",
|
| 587 |
+
"message": f"Pipeline failed: {str(e)}"
|
| 588 |
+
}), 500
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
if __name__ == '__main__':
    # Ensure the scratch/output directory exists before any endpoint writes to it.
    os.makedirs("data", exist_ok=True)
    # Startup banner: list the routes this server exposes.
    print("Starting minimal Flask server...")
    print("Available endpoints:")
    print(" GET /health - Health check")
    print(" GET /dataset/fields?id=<hf_id>[&config=...][&split=...] - List dataset fields")
    print(" GET /dataset/field-stats?id=...&field=... - Get value distribution of a field")
    print(" GET /dataset/meta?id=<hf_id> - List configs/splits")
    print(" POST /pipeline - Run complete pipeline")
    # NOTE(review): debug=True enables Flask's interactive debugger/reloader; this is
    # a security risk if the container is publicly reachable (e.g. on HF Spaces) —
    # confirm whether this entrypoint is only used for local development.
    app.run(host='0.0.0.0', port=5001, debug=True, threaded=True)
|
backend/utils/finetune.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, math, random
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from transformers import (AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling,
|
| 6 |
+
Trainer, TrainingArguments)
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from peft import LoraConfig, get_peft_model, TaskType
|
| 10 |
+
PEFT_AVAILABLE = True
|
| 11 |
+
except Exception:
|
| 12 |
+
PEFT_AVAILABLE = False
|
| 13 |
+
|
| 14 |
+
def build_text_column(df: pd.DataFrame) -> pd.Series:
    """Derive a single training-text Series from a CSV-loaded DataFrame.

    Column lookup is case-insensitive. Preference order:
      1. ``text``                 -> returned stringified as-is;
      2. ``prompt`` + ``generated`` -> rendered into an instruction/response template;
      3. ``generated`` alone       -> returned stringified.

    Raises:
        ValueError: when none of the recognized columns are present.
    """
    # Map lowercase column names back to their original spelling.
    by_lower = {name.lower(): name for name in df.columns}

    if "text" in by_lower:
        return df[by_lower["text"]].astype(str)

    if "prompt" in by_lower and "generated" in by_lower:
        prompt_col = by_lower["prompt"]
        response_col = by_lower["generated"]

        def _render(row) -> str:
            # Alpaca-style instruction/response template, one trailing newline.
            return f"### Instruction:\n{row[prompt_col]}\n\n### Response:\n{row[response_col]}\n"

        return df.apply(_render, axis=1)

    if "generated" in by_lower:
        return df[by_lower["generated"]].astype(str)

    raise ValueError("CSV 缺少可用欄位:請提供 text,或 prompt+generated,或 generated。")
|
| 27 |
+
|
| 28 |
+
def finetune_gpt2_from_csv(
    csv_path: str,
    base_model: str = "gpt2",
    output_dir: str = "data/ft_gpt2_out",
    train_split: float = 0.9,
    epochs: int = 3,
    lr: float = 5e-5,
    batch_size: int = 2,
    use_lora: bool = False,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    seed: int = 42,
    max_length: int = 512,
) -> dict:
    """Fine-tune a causal language model on text columns extracted from a CSV.

    The training text is built by :func:`build_text_column` (``text``, or
    ``prompt``+``generated``, or ``generated``). The first ``train_split``
    fraction of rows is used for training; the remainder (or, when empty, the
    first ~10% of rows) is used for evaluation.

    Args:
        csv_path: Path to the input CSV file.
        base_model: HF model id to fine-tune (defaults to GPT-2).
        output_dir: Directory where the model/tokenizer/checkpoints are saved.
        train_split: Fraction of rows used for training.
        epochs / lr / batch_size: Standard training hyperparameters.
        use_lora: If True and `peft` is installed, wrap the model with LoRA
            adapters; otherwise full-parameter fine-tuning is used.
        lora_r / lora_alpha / lora_dropout: LoRA hyperparameters.
        seed: Random seed for Python and torch RNGs.
        max_length: Tokenizer truncation length.

    Returns:
        dict with output_dir, train/eval sizes, and eval perplexity
        (None when no eval loss was produced).
    """
    os.makedirs(output_dir, exist_ok=True)
    random.seed(seed); torch.manual_seed(seed)

    df = pd.read_csv(csv_path)
    texts = build_text_column(df).fillna("").tolist()

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    # GPT-2 ships without a pad token; reuse EOS so padded batches work.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(base_model)

    if use_lora:
        if not PEFT_AVAILABLE:
            print("PEFT 未安裝,改為全參數微調")
        else:
            lconf = LoraConfig(
                r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                task_type=TaskType.CAUSAL_LM, target_modules=["c_attn","c_proj","q_attn"]  # target modules depend on the model architecture
            )
            model = get_peft_model(model, lconf)

    def tokenize(example_texts):
        # Truncate only; dynamic padding is handled by the data collator.
        return tokenizer(example_texts, truncation=True, max_length=max_length)

    split_idx = int(len(texts) * train_split)
    # Fallback: when the tail split is empty, reuse the first ~10% (at least 1
    # row) as the eval set. NOTE: in that case eval overlaps with train.
    train_texts, val_texts = texts[:split_idx], texts[split_idx:] or texts[: max(1, len(texts)//10)]

    train_enc = tokenize(train_texts)
    val_enc = tokenize(val_texts)

    # Causal-LM objective (mlm=False): labels are the shifted input ids.
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    class SimpleDS(torch.utils.data.Dataset):
        """Minimal map-style dataset over a tokenizer BatchEncoding."""
        def __init__(self, enc): self.enc = enc
        def __len__(self): return len(self.enc["input_ids"])
        def __getitem__(self, idx):
            return {k: torch.tensor(v[idx]) for k, v in self.enc.items()}

    train_ds, val_ds = SimpleDS(train_enc), SimpleDS(val_enc)

    args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        learning_rate=lr,
        warmup_ratio=0.03,
        weight_decay=0.01,
        logging_steps=20,
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=2,
        fp16=torch.cuda.is_available(),
        bf16=torch.cuda.is_bf16_supported() if hasattr(torch.cuda, "is_bf16_supported") else False,
        report_to=[],
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=collator,
    )

    trainer.train()
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Run evaluation exactly once. The previous version called
    # trainer.evaluate() twice in the return expression, doubling the cost of
    # a full evaluation pass for no benefit.
    eval_metrics = trainer.evaluate()
    perplexity = math.exp(eval_metrics["eval_loss"]) if "eval_loss" in eval_metrics else None

    return {
        "output_dir": output_dir,
        "train_size": len(train_ds),
        "eval_size": len(val_ds),
        "perplexity": perplexity
    }
|
backend/utils/sampling.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
|
| 6 |
+
def rank_sample(
    df: pd.DataFrame,
    name_col: str = "name",
    category_col: str = "category",
    sentiment_col: str = "sentiment_score",
    groups: Optional[List[str]] = None,
    num_samples: int = 1000,
    temp: float = 1.0,
    target_value: float = 0.5,
) -> pd.DataFrame:
    """Randomized search for a per-name subset that minimizes the maximum
    pairwise difference of mean scores between categories.

    For each of `num_samples` iterations, one row per unique `name_col` value
    is drawn (softmax-weighted toward rows whose score is closest to
    `target_value`; lower `temp` concentrates on the best-ranked row). The
    candidate subset with the smallest max pairwise category-mean gap wins.

    Args:
        df: Input rows; must contain name_col, category_col, sentiment_col.
        groups: Optional category whitelist; ignored (with a warning) when
            fewer than 2 of them are present in the data.
        num_samples: Number of random candidate subsets to try.
        temp: Softmax temperature over per-name ranks.
        target_value: Score value whose absolute deviation defines the rank.

    Returns:
        Best subset found (one row per name), or a first-row-per-name
        fallback when no valid candidate was produced.

    Raises:
        ValueError: if a required column is missing.
    """

    # Work on a copy; several helper columns are added below.
    df = df.copy()

    for col in [name_col, category_col, sentiment_col]:
        if col not in df.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame")

    # Rows missing any key field cannot participate in sampling.
    df = df.dropna(subset=[name_col, category_col, sentiment_col])

    if groups:
        available_groups = df[category_col].unique()
        valid_groups = [g for g in groups if g in available_groups]
        if len(valid_groups) < 2:
            # Not enough requested groups exist in the data: fall back to all.
            print(f"Warning: Only {len(valid_groups)} groups available from {groups}")
            groups = None
        else:
            groups = valid_groups
            df = df[df[category_col].isin(groups)].copy()

    final_groups = df[category_col].unique()
    if len(final_groups) < 2:
        # A pairwise-gap objective needs at least two categories; degrade to a
        # deterministic first-row-per-name subset.
        print(f"Error: Only {len(final_groups)} groups in data, need at least 2")
        return df.groupby(name_col).first().reset_index()

    print(f"Sampling with groups: {sorted(final_groups)}")
    print(f"Target value for deviation calculation: {target_value}")

    # Rank rows within each name by closeness to the target score
    # (rank 1 = closest; method="first" breaks ties by row order).
    df["sentiment_deviation"] = (df[sentiment_col] - target_value).abs()
    df["sentiment_rank"] = df.groupby(name_col)["sentiment_deviation"].rank(method="first", ascending=True)

    def softmax_weights(ranks: np.ndarray, temp: float) -> np.ndarray:
        # Softmax over negated ranks: lower rank -> higher probability.
        # Guard against temp<=0 and degenerate sums (uniform fallback).
        t = float(temp) if temp and temp > 1e-8 else 1e-8
        x = -ranks / t
        x = x - np.max(x)  # shift for numerical stability
        exps = np.exp(x)
        s = exps.sum()
        return exps / s if np.isfinite(s) and s > 0 else np.ones_like(exps) / len(exps)

    def objective_max_pairwise_diff(frame: pd.DataFrame) -> float:
        # Max absolute difference between any two category means; inf when
        # fewer than two categories survive (treated as invalid).
        g = frame.groupby(category_col)[sentiment_col].mean().dropna()
        if len(g) < 2:
            return np.inf
        vals = g.values
        diffs = np.abs(vals[:, None] - vals[None, :])
        return float(np.max(diffs))

    best_subset = None
    best_obj = np.inf
    valid_samples = 0

    unique_names = df[name_col].nunique()
    print(f"Total unique names: {unique_names}")

    for i in tqdm(range(num_samples), desc="Sampling"):
        try:
            sampled_rows = []

            # Draw exactly one row per name, weighted by rank softmax.
            for name, group in df.groupby(name_col):
                if len(group) == 0:
                    continue

                ranks = group["sentiment_rank"].to_numpy(dtype=float)
                if len(ranks) == 0:
                    continue

                w = softmax_weights(ranks, temp=temp)
                idx = np.random.choice(group.index, p=w)
                sampled_rows.append(df.loc[idx])  # Series; recombined below

            if len(sampled_rows) == 0:
                continue

            subset = pd.DataFrame(sampled_rows)

            # A candidate must still span at least two categories.
            subset_groups = subset[category_col].unique()
            if len(subset_groups) < 2:
                continue

            obj = objective_max_pairwise_diff(subset)

            if np.isfinite(obj):
                valid_samples += 1
                if obj < best_obj:
                    best_obj = obj
                    best_subset = subset.copy()

                # Progress logging: first 10 valid samples, then every 100th.
                if valid_samples % 100 == 0 or valid_samples <= 10:
                    group_means = subset.groupby(category_col)[sentiment_col].mean()
                    print(f"Sample {valid_samples}: obj={obj:.4f}, groups={dict(group_means)}")

        except Exception as e:
            # Best-effort search: a single failed draw should not abort the run.
            print(f"Error in sample {i}: {e}")
            continue

    print(f"Valid samples: {valid_samples}/{num_samples}")
    print(f"Best objective: {best_obj:.4f}")

    if best_subset is None or len(best_subset) == 0:
        print("Warning: No valid samples found, returning fallback subset")
        best_subset = df.groupby(name_col).first().reset_index()

    final_group_counts = best_subset[category_col].value_counts()
    print(f"Final subset group distribution: {dict(final_group_counts)}")

    return best_subset
|
backend/utils/utils.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import TextDataset,DataCollatorForLanguageModeling,Trainer,TrainingArguments
|
| 2 |
+
import torch
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import seaborn as sns
|
| 9 |
+
import numpy as np
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
from transformers import (
|
| 13 |
+
AutoTokenizer,
|
| 14 |
+
AutoModelForCausalLM,
|
| 15 |
+
GPT2LMHeadModel,
|
| 16 |
+
GPT2Tokenizer,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
def load_model_and_tokenizer(model_name: str):
    """Load a causal LM and its tokenizer, moved to the best available device.

    Known GPT-2 checkpoints are loaded via the GPT2-specific classes; anything
    else goes through the Auto* classes. In both cases a missing pad token is
    backfilled with EOS (GPT-2 family has no pad token by default).

    Args:
        model_name: HF hub id or local path of the model.

    Returns:
        (tokenizer, model, device) — model already on `device`.

    Raises:
        RuntimeError: wrapping any loading failure with the model name.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():  # macOS Apple Silicon
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    gpt2_aliases = {"gpt2", "openai-community/gpt2", "holistic-ai/gpt2-EMGSD"}

    try:
        # Previously the two branches duplicated the whole load/pad-token
        # sequence; only the classes differ, so select them once instead.
        if model_name in gpt2_aliases:
            tokenizer_cls, model_cls = GPT2Tokenizer, GPT2LMHeadModel
        else:
            tokenizer_cls, model_cls = AutoTokenizer, AutoModelForCausalLM

        tokenizer = tokenizer_cls.from_pretrained(model_name)
        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token

        model = model_cls.from_pretrained(model_name)
        # Keep the model config consistent with the tokenizer's padding setup.
        if getattr(model.config, "pad_token_id", None) is None and getattr(model.config, "eos_token_id", None) is not None:
            model.config.pad_token_id = model.config.eos_token_id

        model.to(device)
        return tokenizer, model, device
    except Exception as e:
        raise RuntimeError(f"Failed to load model '{model_name}': {e}")
|
| 49 |
+
|
| 50 |
+
def finetune(train_texts, tokenizer, model, num_epochs=20, output_dir='./data'):
    """Fine-tune `model` on a list of raw text strings via a scratch text file.

    Texts are written one-per-line to data/train.txt, wrapped in a
    TextDataset (block_size=128) with a causal-LM collator, and trained with
    per-device batch size 1.

    Args:
        train_texts: Iterable of training strings (stripped before writing).
        tokenizer: Tokenizer matching `model`.
        model: Causal LM to train in place.
        num_epochs: Number of training epochs.
        output_dir: Directory for Trainer outputs/checkpoints.

    Returns:
        The trained model (same object, mutated by training).
    """
    # Ensure the scratch directory exists; previously this raised
    # FileNotFoundError when data/ was absent. (Also dropped the pointless
    # f-string that had no placeholders.)
    os.makedirs("data", exist_ok=True)
    train_path = "data/train.txt"

    with open(train_path, "w", encoding="utf-8") as f:
        for text in train_texts:
            f.write(text.strip() + "\n")

    # NOTE(review): TextDataset is deprecated in recent transformers releases;
    # consider migrating to the `datasets` library when upgrading.
    train_dataset = TextDataset(tokenizer=tokenizer, file_path=train_path, block_size=128)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        per_device_train_batch_size=1,
        num_train_epochs=num_epochs,
        save_steps=500,
        save_total_limit=2,
        logging_dir='./logs',
        logging_steps=10,
        report_to="none"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )

    trainer.train()

    return model
|
| 82 |
+
|
| 83 |
+
def generate_topk_samples(model, df_table, tokenizer, device, top_k=10):
    """Generate `top_k` sampled continuations for each prompt row.

    For every row of `df_table`, the prompt is run through `model.generate`
    with top-k sampling (k = `top_k`, 20 new tokens) producing `top_k`
    sequences, each emitted as one output row.

    Args:
        model: Causal LM already on `device`.
        df_table: DataFrame with columns "prompts", "domain", "name",
            "category", "wikipedia". "prompts" entries may be 1-element lists.
        tokenizer: Tokenizer matching `model`.
        device: torch device for the input tensors.
        top_k: Both the sampling k and the number of sequences per prompt.

    Returns:
        Flat DataFrame with one row per generated sequence (columns: domain,
        name, category, prompts, wikipedia, generated).
    """
    model.eval()
    flat_results = []

    # Fix: work on a copy. The previous version rewrote the caller's
    # df_table["prompts"] column in place — a surprising side effect.
    df_table = df_table.copy()
    # Prompts sometimes arrive wrapped in a single-element list; unwrap them.
    df_table["prompts"] = df_table["prompts"].apply(lambda x: x[0] if isinstance(x, list) else x)

    for idx, row in tqdm(df_table.iterrows(), total=len(df_table), desc="Generating samples"):
        prompt = row["prompts"]

        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            padding=True
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                do_sample=True,
                top_k=top_k,
                max_new_tokens=20,
                top_p=1.0,
                num_return_sequences=top_k,
                pad_token_id=tokenizer.eos_token_id
            )

        for out in outputs:
            # Decoded text includes the prompt itself followed by the sample.
            full_text = tokenizer.decode(out, skip_special_tokens=True).strip()
            flat_results.append({
                "domain": row["domain"],
                "name": row["name"],
                "category": row["category"],
                "prompts": prompt,
                "wikipedia": row["wikipedia"],
                "generated": full_text
            })

    return pd.DataFrame(flat_results)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def evaluate_generated_outputs(
    table: pd.DataFrame,
    device,
    task: str = "sentiment",
    toxicity_model_choice: str = "detoxify",
    text_col: str = "generated",
) -> pd.DataFrame:
    """Score each row's `text_col` with a task-specific classifier.

    Supported tasks: sentiment | regard | stereotype | personality | toxicity.
    Every branch also writes a unified "sentiment_score" column (aliasing or
    aggregating the task-specific score) so downstream ranking code can stay
    task-agnostic. Returns a copy of `table` with score columns added.

    Raises:
        AssertionError: if `text_col` is missing.
        ValueError: on unknown task or toxicity_model_choice.
    """

    assert text_col in table.columns, f"'{text_col}' not found in table columns"

    # transformers pipelines take an int device index: 0 = first CUDA GPU,
    # -1 = CPU (MPS devices are routed to CPU here).
    pipe_device = 0 if (isinstance(device, torch.device) and device.type == "cuda") else -1

    df = table.copy()
    texts = df[text_col].fillna("").astype(str).tolist()

    task = (task or "sentiment").lower()

    if task == "sentiment":
        print("Using default sentiment classifier: lxyuan/distilbert-base-multilingual-cased-sentiments-student")
        tok = AutoTokenizer.from_pretrained("lxyuan/distilbert-base-multilingual-cased-sentiments-student")
        mdl = AutoModelForSequenceClassification.from_pretrained("lxyuan/distilbert-base-multilingual-cased-sentiments-student").to(device).eval()

        scores = []
        for text in tqdm(texts, desc="Scoring (sentiment)"):
            if not text.strip():
                # Empty text gets the neutral midpoint.
                scores.append(0.5)
                continue
            inputs = tok(text, return_tensors="pt", truncation=True, padding=True).to(device)
            with torch.no_grad():
                logits = mdl(**inputs).logits
            probs = F.softmax(logits, dim=1).squeeze(0).tolist()
            # Map (pos - neg) from [-1, 1] to [0, 1].
            # NOTE(review): assumes label order [negative, neutral, positive]
            # (indices 0/2) for this checkpoint — confirm against its config.
            val = (probs[2] - probs[0] + 1.0) / 2.0
            scores.append(float(val))

        df["sentiment_score"] = scores
        return df

    elif task == "regard":
        print("Using default regard classifier: sasha/regardv3")
        clf = pipeline("text-classification", model="sasha/regardv3", device=pipe_device, top_k=None)

        def _safe_relabel(text: str):
            # Normalize pipeline output to {label_lower: score}; on any
            # failure fall back to an uninformative 0.5/0.5 prior.
            try:
                out = clf(text)
                if isinstance(out, list):
                    # top_k=None wraps results per input; unwrap first element
                    # (which may itself be a list of label dicts).
                    out = out[0] if out else {}
                if isinstance(out, dict) and "label" in out and "score" in out:
                    return {out["label"].lower(): float(out["score"])}
                if isinstance(out, list) and out and isinstance(out[0], dict) and "label" in out[0]:
                    d = {}
                    for item in out:
                        d[item["label"].lower()] = float(item["score"])
                    return d
            except Exception:
                pass
            return {"positive": 0.5, "negative": 0.5}

        temp = []
        for text in tqdm(texts, desc="Scoring (regard)"):
            res = _safe_relabel(text)
            pos = float(res.get("positive", 0.5))
            neg = float(res.get("negative", 0.5))
            # NOTE(review): this maps (pos - neg) to [0, 2], unlike the
            # sentiment branch's [0, 1] scaling — confirm intended.
            temp.append(pos - neg + 1.0)

        df["regard_score"] = temp
        # Alias into the unified column consumed by downstream sampling.
        df["sentiment_score"] = df["regard_score"]
        return df

    elif task == "stereotype":
        print("Using default stereotype classifier: holistic-ai/stereotype-deberta-v3-base-tasksource-nli")
        clf = pipeline("text-classification", model="holistic-ai/stereotype-deberta-v3-base-tasksource-nli", device=pipe_device, top_k=None)

        def _safe_relabel(text: str):
            # Normalize output to {label_lower: score}; zeros on failure.
            try:
                out = clf(text)
                if isinstance(out, list) and out and isinstance(out[0], dict) and "label" in out[0]:
                    d = {}
                    for item in out:
                        d[item["label"].lower()] = float(item["score"])
                    return d
                if isinstance(out, dict) and "label" in out:
                    return {out["label"].lower(): float(out.get("score", 0.0))}
            except Exception:
                pass
            return {
                "stereotype_gender": 0.0,
                "stereotype_religion": 0.0,
                "stereotype_profession": 0.0,
                "stereotype_race": 0.0,
            }

        g_list, r_list, p_list, race_list = [], [], [], []
        for text in tqdm(texts, desc="Scoring (stereotype)"):
            d = _safe_relabel(text)
            g_list.append(float(d.get("stereotype_gender", 0.0)))
            r_list.append(float(d.get("stereotype_religion", 0.0)))
            p_list.append(float(d.get("stereotype_profession", 0.0)))
            race_list.append(float(d.get("stereotype_race", 0.0)))

        df["stereotype_gender_score"] = g_list
        df["stereotype_religion_score"] = r_list
        df["stereotype_profession_score"] = p_list
        df["stereotype_race_score"] = race_list

        # Unified column uses the gender score only (the pipeline's default).
        df["sentiment_score"] = df["stereotype_gender_score"]
        return df

    elif task == "personality":
        print("Using default personality classifier: Navya1602/editpersonality_classifier")
        clf = pipeline("text-classification", model="Navya1602/editpersonality_classifier", device=pipe_device, top_k=None)

        traits = ["extraversion", "neuroticism", "agreeableness", "conscientiousness", "openness"]

        def _safe_relabel(text: str):
            # Normalize output to {trait: score}; uniform 0.2 prior (1/5
            # traits) on failure.
            try:
                out = clf(text)
                if isinstance(out, list) and out and isinstance(out[0], dict) and "label" in out[0]:
                    d = {}
                    for item in out:
                        d[item["label"].lower()] = float(item["score"])
                    return d
                if isinstance(out, dict) and "label" in out:
                    return {out["label"].lower(): float(out.get("score", 0.0))}
            except Exception:
                pass
            return {t: 0.2 for t in traits}

        cols = {t: [] for t in traits}
        for text in tqdm(texts, desc="Scoring (personality)"):
            d = _safe_relabel(text)
            for t in traits:
                cols[t].append(float(d.get(t, 0.2)))

        for t in traits:
            df[f"{t}_score"] = cols[t]

        # Unified column is the mean across all five trait scores.
        df["sentiment_score"] = df[[f"{t}_score" for t in traits]].mean(axis=1)
        return df

    elif task == "toxicity":
        if toxicity_model_choice == "detoxify":
            print("Using unitary/toxic-bert model for toxicity classification")
            clf = pipeline("text-classification", model="unitary/toxic-bert", device=pipe_device, top_k=None)
            def _get_toxic_prob(text: str) -> float:
                # Extract the probability of the "toxic" label; 0.0 on any
                # failure or unexpected output shape.
                # NOTE(review): with top_k=None the pipeline may return a
                # nested list per input; if so, it["label"] raises and this
                # silently returns 0.0 — verify the output shape.
                try:
                    out = clf(text)
                    if isinstance(out, list) and out:
                        d = {it["label"].lower(): float(it["score"]) for it in out}
                        return float(d.get("toxic", d.get("toxic/overall", 0.0)))
                    if isinstance(out, dict) and "label" in out:
                        return float(out["score"]) if out["label"].lower() == "toxic" else 0.0
                except Exception:
                    pass
                return 0.0
        elif toxicity_model_choice == "junglelee":
            print("Using JungleLee/bert-toxic-comment-classification for toxicity classification")
            clf = pipeline("text-classification", model="JungleLee/bert-toxic-comment-classification", device=pipe_device)
            def _get_toxic_prob(text: str) -> float:
                # Top-1 pipeline output: score only counts when the winning
                # label mentions "toxic"; 0.0 otherwise/on failure.
                try:
                    out = clf(text)
                    if isinstance(out, dict):
                        lbl = out.get("label", "").lower()
                        score = float(out.get("score", 0.0))
                        return score if "toxic" in lbl else 0.0
                    if isinstance(out, list) and out:
                        for it in out:
                            if "toxic" in it.get("label", "").lower():
                                return float(it.get("score", 0.0))
                except Exception:
                    pass
                return 0.0
        else:
            raise ValueError("Invalid toxicity_model_choice. Choose 'detoxify' or 'junglelee'.")

        tox = []
        for text in tqdm(texts, desc="Scoring (toxicity)"):
            tox.append(_get_toxic_prob(text))

        df["toxicity_score"] = tox
        df["sentiment_score"] = df["toxicity_score"]
        return df

    else:
        raise ValueError(f"Unknown task '{task}'. Use one of: sentiment | regard | stereotype | personality | toxicity")
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
import numpy as np
|
| 312 |
+
import pandas as pd
|
| 313 |
+
from typing import List, Dict, Optional
|
| 314 |
+
|
| 315 |
+
def _generate_cross_category_cf(base_df, text_col, name_col, category_col, num_cf_per_row):
|
| 316 |
+
categories = base_df[category_col].unique().tolist()
|
| 317 |
+
category_names = {}
|
| 318 |
+
|
| 319 |
+
for cat in categories:
|
| 320 |
+
category_names[cat] = base_df[base_df[category_col] == cat][name_col].unique().tolist()
|
| 321 |
+
|
| 322 |
+
print(f"Categories for CF generation: {[f'{cat}({len(names)})' for cat, names in category_names.items()]}")
|
| 323 |
+
|
| 324 |
+
cf_rows = []
|
| 325 |
+
for idx, row in base_df.iterrows():
|
| 326 |
+
original_text = row[text_col]
|
| 327 |
+
original_name = row[name_col]
|
| 328 |
+
original_category = row[category_col]
|
| 329 |
+
original_name_clean = original_name.replace("_", " ")
|
| 330 |
+
|
| 331 |
+
other_categories = [cat for cat in categories if cat != original_category]
|
| 332 |
+
|
| 333 |
+
for target_category in other_categories:
|
| 334 |
+
target_names = category_names[target_category]
|
| 335 |
+
|
| 336 |
+
if len(target_names) == 0:
|
| 337 |
+
continue
|
| 338 |
+
|
| 339 |
+
num_to_sample = min(num_cf_per_row // len(other_categories) + 1, len(target_names))
|
| 340 |
+
if num_to_sample == 0:
|
| 341 |
+
continue
|
| 342 |
+
|
| 343 |
+
sampled_names = np.random.choice(target_names, size=num_to_sample, replace=False)
|
| 344 |
+
|
| 345 |
+
for new_name in sampled_names:
|
| 346 |
+
new_name_clean = new_name.replace("_", " ")
|
| 347 |
+
|
| 348 |
+
new_text = original_text.replace(original_name_clean, new_name_clean, 1)
|
| 349 |
+
|
| 350 |
+
if new_text == original_text:
|
| 351 |
+
original_parts = original_name_clean.split()
|
| 352 |
+
for part in original_parts:
|
| 353 |
+
if len(part) > 2:
|
| 354 |
+
new_text = original_text.replace(part, new_name_clean, 1)
|
| 355 |
+
if new_text != original_text:
|
| 356 |
+
break
|
| 357 |
+
|
| 358 |
+
if new_text == original_text:
|
| 359 |
+
continue
|
| 360 |
+
|
| 361 |
+
new_row = row.copy()
|
| 362 |
+
new_row[name_col] = new_name
|
| 363 |
+
new_row[text_col] = new_text
|
| 364 |
+
new_row[category_col] = target_category
|
| 365 |
+
new_row["original_category"] = original_category
|
| 366 |
+
new_row["cf_type"] = f"{original_category}->{target_category}"
|
| 367 |
+
cf_rows.append(new_row)
|
| 368 |
+
|
| 369 |
+
counterfactual_df = pd.DataFrame(cf_rows)
|
| 370 |
+
|
| 371 |
+
if len(counterfactual_df) > 0:
|
| 372 |
+
cf_stats = counterfactual_df["cf_type"].value_counts()
|
| 373 |
+
print(f"CF generation stats:")
|
| 374 |
+
for cf_type, count in cf_stats.items():
|
| 375 |
+
print(f" {cf_type}: {count}")
|
| 376 |
+
|
| 377 |
+
augmented_df = pd.concat([base_df, counterfactual_df], ignore_index=True)
|
| 378 |
+
|
| 379 |
+
print(f"\nAugmentation Finished: Original {len(base_df)} Added {len(counterfactual_df)} ")
|
| 380 |
+
print(f"Total data len: {len(augmented_df)}")
|
| 381 |
+
|
| 382 |
+
return augmented_df
|
| 383 |
+
|
| 384 |
+
def auto_detect_cf_method(base_df, category_col="category"):
    """Choose the counterfactual strategy from the categories in ``base_df``.

    Returns ``"actors_actresses"`` when both gender-swap categories are
    present, otherwise the generic ``"cross_category"`` strategy.
    """
    present = set(base_df[category_col].unique())
    gender_pair = {"American_actors", "American_actresses"}
    return "actors_actresses" if gender_pair <= present else "cross_category"
|
| 391 |
+
|
| 392 |
+
class Tee:
    """Write-only stream fan-out: duplicates every write to all wrapped streams.

    Useful for mirroring stdout to a log file; flushes eagerly after each
    write so output appears immediately on every stream.
    """

    def __init__(self, *streams):
        # Streams are kept as-is; callers remain responsible for closing them.
        self.streams = streams

    def write(self, data):
        """Write ``data`` to every wrapped stream, flushing each one."""
        for s in self.streams:
            s.write(data)
            s.flush()

    def flush(self):
        """Flush every wrapped stream."""
        for s in self.streams:
            s.flush()
|
| 402 |
+
|
| 403 |
+
def generate_counterfactual_augmentations(base_df, text_col="generated", name_col="name", category_col="category", num_cf_per_row=3):
    """Dispatch counterfactual augmentation to the right strategy.

    Builds a ``{category: [unique names]}`` mapping from ``base_df`` and
    routes to the gender-swap helper when both actor/actress categories are
    present, otherwise to the generic cross-category helper.
    """
    categories = base_df[category_col].unique().tolist()
    category_names = {
        cat: base_df[base_df[category_col] == cat][name_col].unique().tolist()
        for cat in categories
    }

    print(f"Categories for CF generation: {[f'{cat}({len(names)})' for cat, names in category_names.items()]}")

    has_gender_pair = "American_actors" in categories and "American_actresses" in categories
    if has_gender_pair:
        return _generate_actors_actresses_cf(base_df, text_col, name_col, category_col, num_cf_per_row, category_names)
    return _generate_cross_category_cf(base_df, text_col, name_col, category_col, num_cf_per_row, category_names)
|
| 416 |
+
|
| 417 |
+
def _generate_actors_actresses_cf(base_df, text_col, name_col, category_col, num_cf_per_row, category_names):
|
| 418 |
+
male_names = category_names.get("American_actors", [])
|
| 419 |
+
female_names = category_names.get("American_actresses", [])
|
| 420 |
+
|
| 421 |
+
cf_rows = []
|
| 422 |
+
for idx, row in base_df.iterrows():
|
| 423 |
+
original_text = row[text_col]
|
| 424 |
+
original_name = row[name_col]
|
| 425 |
+
category = row[category_col]
|
| 426 |
+
original_name_clean = original_name.replace("_", " ")
|
| 427 |
+
|
| 428 |
+
if category == "American_actors":
|
| 429 |
+
swap_pool = female_names
|
| 430 |
+
new_category = "American_actresses"
|
| 431 |
+
elif category == "American_actresses":
|
| 432 |
+
swap_pool = male_names
|
| 433 |
+
new_category = "American_actors"
|
| 434 |
+
else:
|
| 435 |
+
continue
|
| 436 |
+
|
| 437 |
+
if len(swap_pool) == 0:
|
| 438 |
+
continue
|
| 439 |
+
|
| 440 |
+
sampled_names = np.random.choice(swap_pool, size=min(num_cf_per_row, len(swap_pool)), replace=False)
|
| 441 |
+
|
| 442 |
+
for new_name in sampled_names:
|
| 443 |
+
new_name_clean = new_name.replace("_", " ")
|
| 444 |
+
new_text = original_text.replace(original_name_clean, new_name_clean, 1)
|
| 445 |
+
|
| 446 |
+
if new_text == original_text:
|
| 447 |
+
continue
|
| 448 |
+
|
| 449 |
+
new_row = row.copy()
|
| 450 |
+
new_row[name_col] = new_name
|
| 451 |
+
new_row[text_col] = new_text
|
| 452 |
+
new_row[category_col] = new_category
|
| 453 |
+
new_row["original_category"] = category
|
| 454 |
+
cf_rows.append(new_row)
|
| 455 |
+
|
| 456 |
+
counterfactual_df = pd.DataFrame(cf_rows)
|
| 457 |
+
augmented_df = pd.concat([base_df, counterfactual_df], ignore_index=True)
|
| 458 |
+
|
| 459 |
+
print(f"\nAugmentation Finished: Original {len(base_df)} Added {len(counterfactual_df)} ")
|
| 460 |
+
print(f"Total data len: {len(augmented_df)}")
|
| 461 |
+
return augmented_df
|
| 462 |
+
|
| 463 |
+
def _generate_cross_category_cf(base_df, text_col, name_col, category_col, num_cf_per_row, category_names):
|
| 464 |
+
categories = list(category_names.keys())
|
| 465 |
+
|
| 466 |
+
cf_rows = []
|
| 467 |
+
for idx, row in base_df.iterrows():
|
| 468 |
+
original_text = row[text_col]
|
| 469 |
+
original_name = row[name_col]
|
| 470 |
+
original_category = row[category_col]
|
| 471 |
+
original_name_clean = original_name.replace("_", " ")
|
| 472 |
+
|
| 473 |
+
other_categories = [cat for cat in categories if cat != original_category]
|
| 474 |
+
|
| 475 |
+
for target_category in other_categories:
|
| 476 |
+
target_names = category_names[target_category]
|
| 477 |
+
|
| 478 |
+
if len(target_names) == 0:
|
| 479 |
+
continue
|
| 480 |
+
|
| 481 |
+
num_to_sample = min(max(1, num_cf_per_row // len(other_categories)), len(target_names))
|
| 482 |
+
sampled_names = np.random.choice(target_names, size=num_to_sample, replace=False)
|
| 483 |
+
|
| 484 |
+
for new_name in sampled_names:
|
| 485 |
+
new_name_clean = new_name.replace("_", " ")
|
| 486 |
+
|
| 487 |
+
new_text = original_text.replace(original_name_clean, new_name_clean, 1)
|
| 488 |
+
|
| 489 |
+
if new_text == original_text:
|
| 490 |
+
original_parts = original_name_clean.split()
|
| 491 |
+
for part in original_parts:
|
| 492 |
+
if len(part) > 2:
|
| 493 |
+
new_text = original_text.replace(part, new_name_clean, 1)
|
| 494 |
+
if new_text != original_text:
|
| 495 |
+
break
|
| 496 |
+
|
| 497 |
+
if new_text == original_text:
|
| 498 |
+
continue
|
| 499 |
+
|
| 500 |
+
new_row = row.copy()
|
| 501 |
+
new_row[name_col] = new_name
|
| 502 |
+
new_row[text_col] = new_text
|
| 503 |
+
new_row[category_col] = target_category
|
| 504 |
+
new_row["original_category"] = original_category
|
| 505 |
+
cf_rows.append(new_row)
|
| 506 |
+
|
| 507 |
+
counterfactual_df = pd.DataFrame(cf_rows)
|
| 508 |
+
augmented_df = pd.concat([base_df, counterfactual_df], ignore_index=True)
|
| 509 |
+
|
| 510 |
+
print(f"\nAugmentation Finished: Original {len(base_df)} Added {len(counterfactual_df)} ")
|
| 511 |
+
print(f"Total data len: {len(augmented_df)}")
|
| 512 |
+
|
| 513 |
+
return augmented_df
|
| 514 |
+
|
| 515 |
+
def _ensure_plot_saved(
    df,
    score_col: str,
    basename: str,
    group_col: Optional[str] = None,
    target: Optional[float] = None,
    bins: int = 30,
) -> str:
    """Render a density histogram of ``df[score_col]`` and save it as a PNG.

    When ``group_col`` is given and present, one overlaid histogram is drawn
    per group; otherwise a single histogram of all scores. A dashed vertical
    line marks the overall mean, and an optional dash-dot line marks
    ``target``. The figure is written to ``data/<basename>.png``.

    NOTE(review): uses the global pyplot state machine (``plt``), so this is
    not safe to call concurrently from multiple threads — confirm callers are
    single-threaded.

    Args:
        df: DataFrame holding the scores.
        score_col: numeric column to plot (NaNs are dropped).
        basename: output file name without extension; underscores become
            spaces in the plot title.
        group_col: optional column to split the histogram by.
        target: optional reference value to mark with a vertical line.
        bins: number of histogram bins.

    Returns:
        The path of the saved PNG ("data/<basename>.png").
    """
    os.makedirs("data", exist_ok=True)
    path = os.path.join("data", f"{basename}.png")

    plt.figure(figsize=(8, 5))
    data = df[score_col].dropna().values

    if group_col and group_col in df.columns:
        # One semi-transparent histogram per group, labelled with n and mean.
        for g, sub in df.groupby(group_col):
            vals = sub[score_col].dropna().values
            if len(vals) == 0:
                continue
            plt.hist(vals, bins=bins, alpha=0.4, label=f"{g} (n={len(vals)}, μ={np.mean(vals):.3f})", density=True)
    else:
        plt.hist(data, bins=bins, alpha=0.6, density=True, label=f"All (n={len(data)}, μ={np.mean(data):.3f})")

    # Overall mean line (skipped when there is no data at all).
    if len(data):
        m = float(np.mean(data))
        plt.axvline(m, linestyle="--", linewidth=2, label=f"mean={m:.3f}")

    if target is not None:
        plt.axvline(target, linestyle="-.", linewidth=2, label=f"target={target:.3f}")

    plt.xlabel(score_col)
    plt.ylabel("density")
    plt.title(basename.replace("_", " "))
    plt.legend(loc="best")
    plt.tight_layout()
    plt.savefig(path, dpi=160)
    plt.close()  # release the figure so repeated calls don't leak memory
    return path
|
frontend/.gitignore
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Logs
|
| 2 |
+
logs
|
| 3 |
+
*.log
|
| 4 |
+
npm-debug.log*
|
| 5 |
+
yarn-debug.log*
|
| 6 |
+
yarn-error.log*
|
| 7 |
+
pnpm-debug.log*
|
| 8 |
+
lerna-debug.log*
|
| 9 |
+
|
| 10 |
+
node_modules
|
| 11 |
+
dist
|
| 12 |
+
dist-ssr
|
| 13 |
+
*.local
|
| 14 |
+
|
| 15 |
+
# Editor directories and files
|
| 16 |
+
.vscode/*
|
| 17 |
+
!.vscode/extensions.json
|
| 18 |
+
.idea
|
| 19 |
+
.DS_Store
|
| 20 |
+
*.suo
|
| 21 |
+
*.ntvs*
|
| 22 |
+
*.njsproj
|
| 23 |
+
*.sln
|
| 24 |
+
*.sw?
|
frontend/README.md
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# React + TypeScript + Vite
|
| 2 |
+
|
| 3 |
+
This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
|
| 4 |
+
|
| 5 |
+
Currently, two official plugins are available:
|
| 6 |
+
|
| 7 |
+
- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) for Fast Refresh
|
| 8 |
+
- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
|
| 9 |
+
|
| 10 |
+
## Expanding the ESLint configuration
|
| 11 |
+
|
| 12 |
+
If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules:
|
| 13 |
+
|
| 14 |
+
```js
|
| 15 |
+
export default tseslint.config([
|
| 16 |
+
globalIgnores(['dist']),
|
| 17 |
+
{
|
| 18 |
+
files: ['**/*.{ts,tsx}'],
|
| 19 |
+
extends: [
|
| 20 |
+
// Other configs...
|
| 21 |
+
|
| 22 |
+
// Remove tseslint.configs.recommended and replace with this
|
| 23 |
+
...tseslint.configs.recommendedTypeChecked,
|
| 24 |
+
// Alternatively, use this for stricter rules
|
| 25 |
+
...tseslint.configs.strictTypeChecked,
|
| 26 |
+
// Optionally, add this for stylistic rules
|
| 27 |
+
...tseslint.configs.stylisticTypeChecked,
|
| 28 |
+
|
| 29 |
+
// Other configs...
|
| 30 |
+
],
|
| 31 |
+
languageOptions: {
|
| 32 |
+
parserOptions: {
|
| 33 |
+
project: ['./tsconfig.node.json', './tsconfig.app.json'],
|
| 34 |
+
tsconfigRootDir: import.meta.dirname,
|
| 35 |
+
},
|
| 36 |
+
// other options...
|
| 37 |
+
},
|
| 38 |
+
},
|
| 39 |
+
])
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
You can also install [eslint-plugin-react-x](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-x) and [eslint-plugin-react-dom](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-dom) for React-specific lint rules:
|
| 43 |
+
|
| 44 |
+
```js
|
| 45 |
+
// eslint.config.js
|
| 46 |
+
import reactX from 'eslint-plugin-react-x'
|
| 47 |
+
import reactDom from 'eslint-plugin-react-dom'
|
| 48 |
+
|
| 49 |
+
export default tseslint.config([
|
| 50 |
+
globalIgnores(['dist']),
|
| 51 |
+
{
|
| 52 |
+
files: ['**/*.{ts,tsx}'],
|
| 53 |
+
extends: [
|
| 54 |
+
// Other configs...
|
| 55 |
+
// Enable lint rules for React
|
| 56 |
+
reactX.configs['recommended-typescript'],
|
| 57 |
+
// Enable lint rules for React DOM
|
| 58 |
+
reactDom.configs.recommended,
|
| 59 |
+
],
|
| 60 |
+
languageOptions: {
|
| 61 |
+
parserOptions: {
|
| 62 |
+
project: ['./tsconfig.node.json', './tsconfig.app.json'],
|
| 63 |
+
tsconfigRootDir: import.meta.dirname,
|
| 64 |
+
},
|
| 65 |
+
// other options...
|
| 66 |
+
},
|
| 67 |
+
},
|
| 68 |
+
])
|
| 69 |
+
```
|
frontend/eslint.config.js
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// ESLint flat config for the React + TypeScript frontend (Vite template).
import js from '@eslint/js'
import globals from 'globals'
import reactHooks from 'eslint-plugin-react-hooks'
import reactRefresh from 'eslint-plugin-react-refresh'
import tseslint from 'typescript-eslint'
import { globalIgnores } from 'eslint/config'

export default tseslint.config([
  // Never lint build output.
  globalIgnores(['dist']),
  {
    files: ['**/*.{ts,tsx}'],
    extends: [
      js.configs.recommended,
      tseslint.configs.recommended,
      reactHooks.configs['recommended-latest'],
      reactRefresh.configs.vite,
    ],
    languageOptions: {
      ecmaVersion: 2020,
      // Browser globals (window, document, …) — this is a client-side app.
      globals: globals.browser,
    },
  },
])
])
|
frontend/index.html
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!doctype html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8" />
|
| 5 |
+
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
| 6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 7 |
+
<title>Vite + React + TS</title>
|
| 8 |
+
</head>
|
| 9 |
+
<body>
|
| 10 |
+
<div id="root"></div>
|
| 11 |
+
<script type="module" src="/src/main.tsx"></script>
|
| 12 |
+
</body>
|
| 13 |
+
</html>
|
frontend/package-lock.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
frontend/package.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "frontend",
|
| 3 |
+
"private": true,
|
| 4 |
+
"version": "0.0.0",
|
| 5 |
+
"type": "module",
|
| 6 |
+
"scripts": {
|
| 7 |
+
"dev": "vite",
|
| 8 |
+
"build": "vite build",
|
| 9 |
+
"lint": "eslint .",
|
| 10 |
+
"preview": "vite preview"
|
| 11 |
+
},
|
| 12 |
+
"dependencies": {
|
| 13 |
+
"lucide-react": "^0.542.0",
|
| 14 |
+
"react": "^19.1.1",
|
| 15 |
+
"react-dom": "^19.1.1",
|
| 16 |
+
"recharts": "^3.1.2"
|
| 17 |
+
},
|
| 18 |
+
"devDependencies": {
|
| 19 |
+
"@eslint/js": "^9.33.0",
|
| 20 |
+
"@tailwindcss/forms": "^0.5.10",
|
| 21 |
+
"@types/react": "^19.1.10",
|
| 22 |
+
"@types/react-dom": "^19.1.7",
|
| 23 |
+
"@vitejs/plugin-react": "^5.0.0",
|
| 24 |
+
"autoprefixer": "^10.4.21",
|
| 25 |
+
"eslint": "^9.33.0",
|
| 26 |
+
"eslint-plugin-react-hooks": "^5.2.0",
|
| 27 |
+
"eslint-plugin-react-refresh": "^0.4.20",
|
| 28 |
+
"globals": "^16.3.0",
|
| 29 |
+
"postcss": "^8.5.6",
|
| 30 |
+
"tailwindcss": "^3.4.17",
|
| 31 |
+
"typescript": "~5.8.3",
|
| 32 |
+
"typescript-eslint": "^8.39.1",
|
| 33 |
+
"vite": "^7.1.2"
|
| 34 |
+
}
|
| 35 |
+
}
|
frontend/postcss.config.js
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// PostCSS pipeline: run Tailwind first, then add vendor prefixes.
export default {
  plugins: {
    tailwindcss: {},
    autoprefixer: {},
  },
}
|
frontend/public/vite.svg
ADDED
|
|
frontend/src/App.css
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#root {
|
| 2 |
+
max-width: 1280px;
|
| 3 |
+
margin: 0 auto;
|
| 4 |
+
padding: 2rem;
|
| 5 |
+
text-align: center;
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
.logo {
|
| 9 |
+
height: 6em;
|
| 10 |
+
padding: 1.5em;
|
| 11 |
+
will-change: filter;
|
| 12 |
+
transition: filter 300ms;
|
| 13 |
+
}
|
| 14 |
+
.logo:hover {
|
| 15 |
+
filter: drop-shadow(0 0 2em #646cffaa);
|
| 16 |
+
}
|
| 17 |
+
.logo.react:hover {
|
| 18 |
+
filter: drop-shadow(0 0 2em #61dafbaa);
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
@keyframes logo-spin {
|
| 22 |
+
from {
|
| 23 |
+
transform: rotate(0deg);
|
| 24 |
+
}
|
| 25 |
+
to {
|
| 26 |
+
transform: rotate(360deg);
|
| 27 |
+
}
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
@media (prefers-reduced-motion: no-preference) {
|
| 31 |
+
a:nth-of-type(2) .logo {
|
| 32 |
+
animation: logo-spin infinite 20s linear;
|
| 33 |
+
}
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
.card {
|
| 37 |
+
padding: 2em;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
.read-the-docs {
|
| 41 |
+
color: #888;
|
| 42 |
+
}
|
frontend/src/App.tsx
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useState, Suspense, lazy } from 'react';
|
| 2 |
+
import { Target } from 'lucide-react';
|
| 3 |
+
import type { JobConfig, Extras } from './types';
|
| 4 |
+
import { JobRunnerProvider, useJobRunner } from './hooks/JobRunnerProvider';
|
| 5 |
+
|
| 6 |
+
const ConfigPage = lazy(() => import('./pages/ConfigPage'));
|
| 7 |
+
const ResultsPage = lazy(() => import('./pages/ResultsPage'));
|
| 8 |
+
|
| 9 |
+
type Tab = 'config'|'results'|'reports';
|
| 10 |
+
|
| 11 |
+
// Inner app shell: sticky header, tab navigation, and lazy-loaded pages.
// Must render inside JobRunnerProvider (it calls useJobRunner).
function AppInner() {
  const [tab, setTab] = useState<Tab>('config');
  const { start } = useJobRunner();

  // Start a job with the given config/extras and jump to the results tab.
  const run = (cfg: JobConfig, extras: Extras) => {
    start(cfg, extras);
    setTab('results');
  };

  // Tailwind classes for a tab button; gradient highlight when active.
  const tabBtn = (active: boolean) =>
    `relative px-4 py-2 rounded-xl text-sm font-medium transition-all ${
      active
        ? 'text-white bg-gradient-to-r from-indigo-600 via-violet-600 to-fuchsia-600 shadow-lg shadow-indigo-600/20'
        : 'text-slate-600 hover:text-slate-800 bg-white/70 backdrop-blur border border-white/30 hover:shadow-md'
    }`;

  return (
    <div className="min-h-screen bg-[radial-gradient(1200px_800px_at_20%_-10%,#c7d2fe_0%,transparent_60%),radial-gradient(1200px_800px_at_120%_20%,#fbcfe8_0%,transparent_55%)] bg-slate-50">
      <header className="sticky top-0 z-50">
        <div className="backdrop-blur-xl bg-white/60 border-b border-white/40 shadow-[0_10px_30px_-12px_rgba(30,41,59,0.25)] supports-[backdrop-filter]:bg-white/40">
          <div className="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 h-16 flex items-center justify-between">
            <div className="flex items-center gap-3">
              <div className="p-2 rounded-xl bg-gradient-to-br from-indigo-600 to-fuchsia-600 shadow-md shadow-indigo-600/30">
                <Target className="w-6 h-6 text-white" />
              </div>
              <div>
                <h1 className="text-lg sm:text-xl font-bold tracking-tight text-slate-900">AAAI Demo</h1>
                <p className="text-xs text-slate-600">Rank Sampling</p>
              </div>
            </div>

            <div className="hidden sm:flex items-center gap-2">
              <span className="px-3 py-1 rounded-full text-xs font-medium bg-emerald-50 text-emerald-700 border border-emerald-200">System Health</span>
              <span className="px-3 py-1 rounded-full text-xs font-medium bg-slate-100 text-slate-700 border border-slate-200">v1.0.0</span>
            </div>
          </div>
        </div>

        <div className="border-b border-white/40 bg-white/30 backdrop-blur supports-[backdrop-filter]:bg-white/20">
          <div className="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 py-3 flex gap-3">
            <button className={tabBtn(tab==='config')} onClick={() => setTab('config')}>Config Setting</button>
            <button className={tabBtn(tab==='results')} onClick={() => setTab('results')}>Results</button>
          </div>
        </div>
      </header>

      <main className="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 py-8">
        <Suspense fallback={
          <div className="p-8 rounded-2xl border border-white/30 bg-white/60 backdrop-blur text-slate-600">
            Loading
          </div>
        }>
          {/* Changed to pass extras along with the config */}
          {tab === 'config' && <ConfigPage onRun={run} />}
          {tab === 'results' && <ResultsPage />}
        </Suspense>
      </main>
    </div>
  );
}
|
| 71 |
+
|
| 72 |
+
// Root component: wraps the app shell in the job-runner context provider.
export default function App() {
  return (
    <JobRunnerProvider>
      <AppInner />
    </JobRunnerProvider>
  );
}
|
frontend/src/assets/react.svg
ADDED
|
|
frontend/src/components/MetricCard.tsx
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// src/components/MetricCard.tsx
|
| 2 |
+
import React from 'react';
|
| 3 |
+
import type { ReactNode } from 'react';
|
| 4 |
+
import { TrendingDown, TrendingUp } from 'lucide-react';
|
| 5 |
+
|
| 6 |
+
// Summary card for a single metric: title, headline value, optional
// change indicator with trend arrow, and an optional description line.
// NOTE(review): `positive` renders a *downward* arrow in green — presumably
// "lower is better" for bias metrics; confirm against callers.
export default function MetricCard({
  title, value, change, positive, icon, description
}: {
  title: string; value: string; change?: string; positive?: boolean; icon: ReactNode; description?: string;
}) {
  return (
    <div className="bg-white rounded-lg shadow-sm border p-6 hover:shadow-md transition-shadow">
      <div className="flex items-center justify-between">
        <div className="flex-1">
          <p className="text-sm font-medium text-gray-600">{title}</p>
          <p className="text-2xl font-bold text-gray-900 mt-1">{value}</p>
          {change && (
            <div className="flex items-center mt-2">
              {positive ? <TrendingDown className="w-4 h-4 text-green-500 mr-1" /> : <TrendingUp className="w-4 h-4 text-red-500 mr-1" />}
              <span className={`text-sm font-medium ${positive ? 'text-green-600' : 'text-red-600'}`}>{change}</span>
            </div>
          )}
          {description && <p className="text-xs text-gray-500 mt-1">{description}</p>}
        </div>
        <div className="p-3 bg-blue-50 rounded-full">{icon}</div>
      </div>
    </div>
  );
}
|
frontend/src/components/PipelineProgress.tsx
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useEffect, useMemo, useRef, useState } from 'react';
|
| 2 |
+
import { CheckCircle2, Loader2, Database, Brain, Sparkles, Rocket, LineChart } from 'lucide-react';
|
| 3 |
+
import { MLBiasAPI } from '../services/api';
|
| 4 |
+
import { useJobRunner } from '../hooks/JobRunnerProvider';
|
| 5 |
+
|
| 6 |
+
type Health = {
|
| 7 |
+
status?: string;
|
| 8 |
+
timestamp?: string;
|
| 9 |
+
loaded_models?: string[];
|
| 10 |
+
dataset_loaded?: boolean;
|
| 11 |
+
generation_results_available?: boolean;
|
| 12 |
+
finetune_running?: boolean;
|
| 13 |
+
steps?: Record<string, boolean | 'todo' | 'doing' | 'done'>;
|
| 14 |
+
};
|
| 15 |
+
|
| 16 |
+
type StepKey =
|
| 17 |
+
| 'Activate Task'
|
| 18 |
+
| 'Load Dataset'
|
| 19 |
+
| 'Load Model'
|
| 20 |
+
| 'Generate and Score'
|
| 21 |
+
| 'Counterfactual'
|
| 22 |
+
| 'Sampling'
|
| 23 |
+
| 'Plot and Output'
|
| 24 |
+
| 'Finetune';
|
| 25 |
+
|
| 26 |
+
type StepState = 'todo' | 'doing' | 'done';
|
| 27 |
+
|
| 28 |
+
export default function PipelineProgress() {
|
| 29 |
+
const { result, resp } = useJobRunner();
|
| 30 |
+
|
| 31 |
+
const [health, setHealth] = useState<Health | null>(null);
|
| 32 |
+
const pollRef = useRef<number | null>(null);
|
| 33 |
+
|
| 34 |
+
useEffect(() => {
|
| 35 |
+
const poll = async () => {
|
| 36 |
+
try {
|
| 37 |
+
const h = (await MLBiasAPI.checkHealth()) as Health;
|
| 38 |
+
setHealth(prev => (JSON.stringify(prev) === JSON.stringify(h) ? prev : h));
|
| 39 |
+
} catch {
|
| 40 |
+
}
|
| 41 |
+
};
|
| 42 |
+
void poll();
|
| 43 |
+
pollRef.current = window.setInterval(poll, 1000);
|
| 44 |
+
return () => {
|
| 45 |
+
if (pollRef.current) window.clearInterval(pollRef.current);
|
| 46 |
+
};
|
| 47 |
+
}, []);
|
| 48 |
+
|
| 49 |
+
const [elapsed, setElapsed] = useState<number>(0);
|
| 50 |
+
const timerRef = useRef<number | null>(null);
|
| 51 |
+
useEffect(() => {
|
| 52 |
+
const startedAt = Date.now();
|
| 53 |
+
timerRef.current = window.setInterval(() => {
|
| 54 |
+
setElapsed(Math.floor((Date.now() - startedAt) / 1000));
|
| 55 |
+
}, 1000);
|
| 56 |
+
return () => {
|
| 57 |
+
if (timerRef.current) window.clearInterval(timerRef.current);
|
| 58 |
+
};
|
| 59 |
+
}, []);
|
| 60 |
+
|
| 61 |
+
const modelName = result?.config?.languageModel || '';
|
| 62 |
+
|
| 63 |
+
const wantFT = Boolean(
|
| 64 |
+
result?.config?.enableFineTuning ?? (resp?.results as any)?.config_used?.enableFineTuning
|
| 65 |
+
);
|
| 66 |
+
|
| 67 |
+
const backendSteps = useMemo(() => {
|
| 68 |
+
const fromResp = ((resp?.results as any)?.steps || {}) as Record<string, boolean>;
|
| 69 |
+
const fromHealth = ((health?.steps || {}) as Record<string, boolean | 'todo' | 'doing' | 'done'>);
|
| 70 |
+
const merged: Record<string, boolean> = { ...fromResp };
|
| 71 |
+
Object.keys(fromHealth).forEach(k => {
|
| 72 |
+
const v = (fromHealth as any)[k];
|
| 73 |
+
merged[k] = v === true || v === 'doing' || v === 'done';
|
| 74 |
+
});
|
| 75 |
+
return merged;
|
| 76 |
+
}, [health?.steps, resp?.results]);
|
| 77 |
+
|
| 78 |
+
const resultsAny = (resp?.results ?? {}) as any;
|
| 79 |
+
|
| 80 |
+
const inferred = useMemo(() => {
|
| 81 |
+
const hasData = Boolean(health?.dataset_loaded);
|
| 82 |
+
const hasModel = Boolean(health?.loaded_models && health.loaded_models.length > 0);
|
| 83 |
+
|
| 84 |
+
const genDone = Boolean(
|
| 85 |
+
backendSteps['3_generate_and_eval'] ||
|
| 86 |
+
health?.generation_results_available ||
|
| 87 |
+
resultsAny.generation_done
|
| 88 |
+
);
|
| 89 |
+
|
| 90 |
+
const r4Flag = Boolean(
|
| 91 |
+
backendSteps['4_rank_sampling_original'] ||
|
| 92 |
+
resultsAny.rank_sampling_original_done ||
|
| 93 |
+
resultsAny.rank_sampling?.original_done
|
| 94 |
+
);
|
| 95 |
+
|
| 96 |
+
const r5Flag = Boolean(
|
| 97 |
+
backendSteps['5_rank_sampling_cf'] ||
|
| 98 |
+
resultsAny.rank_sampling_cf_done ||
|
| 99 |
+
resultsAny.rank_sampling?.cf_done
|
| 100 |
+
);
|
| 101 |
+
|
| 102 |
+
const plotsFlag = Boolean(
|
| 103 |
+
backendSteps['6_plots_and_metrics'] ||
|
| 104 |
+
resultsAny.plot_urls ||
|
| 105 |
+
resultsAny.plots_ready ||
|
| 106 |
+
(resultsAny.plots &&
|
| 107 |
+
(resultsAny.plots.original_sentiment || resultsAny.plots.counterfactual_sentiment))
|
| 108 |
+
);
|
| 109 |
+
|
| 110 |
+
const ftDoneFlag = Boolean(
|
| 111 |
+
backendSteps['7_finetune'] === true ||
|
| 112 |
+
resultsAny.finetune_done ||
|
| 113 |
+
resultsAny.finetune?.completed ||
|
| 114 |
+
resultsAny.finetune?.saved_model_path
|
| 115 |
+
);
|
| 116 |
+
|
| 117 |
+
const ftRunning = Boolean(resultsAny.finetune?.running || (health as any)?.finetune_running);
|
| 118 |
+
|
| 119 |
+
const noStepSignals =
|
| 120 |
+
Object.keys(backendSteps || {}).length === 0 &&
|
| 121 |
+
!resultsAny.rank_sampling_original_done &&
|
| 122 |
+
!resultsAny.rank_sampling_cf_done &&
|
| 123 |
+
!resultsAny.plots_ready &&
|
| 124 |
+
!resultsAny.finetune_done;
|
| 125 |
+
|
| 126 |
+
const cfByTime = noStepSignals && genDone && elapsed > 30;
|
| 127 |
+
const rsByTime = noStepSignals && genDone && elapsed > 45;
|
| 128 |
+
const plotsByTime= noStepSignals && genDone && elapsed > 70;
|
| 129 |
+
|
| 130 |
+
const cfDone = Boolean(
|
| 131 |
+
backendSteps['3_5_counterfactual'] ||
|
| 132 |
+
resultsAny.counterfactual_done ||
|
| 133 |
+
resultsAny.counterfactual_results ||
|
| 134 |
+
r4Flag || r5Flag || plotsFlag || ftDoneFlag ||
|
| 135 |
+
cfByTime
|
| 136 |
+
);
|
| 137 |
+
|
| 138 |
+
const r4 = r4Flag || rsByTime;
|
| 139 |
+
const r5 = r5Flag || rsByTime;
|
| 140 |
+
const plots = plotsFlag || plotsByTime;
|
| 141 |
+
const ftDone = ftDoneFlag;
|
| 142 |
+
|
| 143 |
+
return { hasData, hasModel, genDone, cfDone, r4, r5, plots, ftDone, ftRunning };
|
| 144 |
+
}, [backendSteps, health, resultsAny, elapsed]);
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
const rawSteps = useMemo<Record<StepKey, StepState>>(() => {
|
| 148 |
+
const states: Record<StepKey, StepState> = {
|
| 149 |
+
'Activate Task': 'todo',
|
| 150 |
+
'Load Dataset': 'todo',
|
| 151 |
+
'Load Model': 'todo',
|
| 152 |
+
'Generate and Score': 'todo',
|
| 153 |
+
'Counterfactual': 'todo',
|
| 154 |
+
'Sampling': 'todo',
|
| 155 |
+
'Plot and Output': 'todo',
|
| 156 |
+
'Finetune': 'todo',
|
| 157 |
+
};
|
| 158 |
+
|
| 159 |
+
if (result?.status === 'running') {
|
| 160 |
+
states['Activate Task'] = 'doing';
|
| 161 |
+
}
|
| 162 |
+
if (inferred.hasData) {
|
| 163 |
+
states['Activate Task'] = 'done';
|
| 164 |
+
states['Load Dataset'] = 'done';
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
if (inferred.hasModel) {
|
| 168 |
+
states['Load Model'] = 'done';
|
| 169 |
+
} else if (inferred.hasData) {
|
| 170 |
+
states['Load Model'] = 'doing';
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
if (inferred.genDone) {
|
| 174 |
+
states['Generate and Score'] = 'done';
|
| 175 |
+
} else if (inferred.hasModel) {
|
| 176 |
+
states['Generate and Score'] = 'doing';
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
if (inferred.cfDone) {
|
| 180 |
+
states['Counterfactual'] = 'done';
|
| 181 |
+
} else if (states['Generate and Score'] === 'done') {
|
| 182 |
+
states['Counterfactual'] = 'doing';
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
const shouldStartSampling =
|
| 186 |
+
inferred.r4 || inferred.r5 ||
|
| 187 |
+
states['Counterfactual'] === 'done' ||
|
| 188 |
+
(states['Generate and Score'] === 'done' && elapsed > 20);
|
| 189 |
+
|
| 190 |
+
if (inferred.r4 && inferred.r5) {
|
| 191 |
+
states['Sampling'] = 'done';
|
| 192 |
+
} else if (shouldStartSampling) {
|
| 193 |
+
states['Sampling'] = 'doing';
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
const shouldStartPlotting =
|
| 197 |
+
inferred.plots ||
|
| 198 |
+
states['Sampling'] === 'done' ||
|
| 199 |
+
(states['Sampling'] === 'doing' && elapsed > 40);
|
| 200 |
+
|
| 201 |
+
if (inferred.plots) {
|
| 202 |
+
states['Plot and Output'] = 'done';
|
| 203 |
+
} else if (shouldStartPlotting) {
|
| 204 |
+
states['Plot and Output'] = 'doing';
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
if (wantFT) {
|
| 208 |
+
if (inferred.ftDone) states['Finetune'] = 'done';
|
| 209 |
+
else if (inferred.ftRunning || states['Plot and Output'] === 'done')
|
| 210 |
+
states['Finetune'] = 'doing';
|
| 211 |
+
else states['Finetune'] = 'todo';
|
| 212 |
+
} else {
|
| 213 |
+
states['Finetune'] = 'todo';
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
return states;
|
| 217 |
+
}, [elapsed, inferred, wantFT, result?.status]);
|
| 218 |
+
|
| 219 |
+
const STUCK_TIMEOUT = 30; // 秒
|
| 220 |
+
const [enteredAt, setEnteredAt] = useState<Record<StepKey, number>>({} as any);
|
| 221 |
+
const [forcedDone, setForcedDone] = useState<Record<StepKey, boolean>>({} as any);
|
| 222 |
+
|
| 223 |
+
useEffect(() => {
|
| 224 |
+
const next: Record<StepKey, number> = { ...enteredAt } as any;
|
| 225 |
+
(Object.keys(rawSteps) as StepKey[]).forEach((k) => {
|
| 226 |
+
if (rawSteps[k] === 'doing' && !next[k]) next[k] = Date.now();
|
| 227 |
+
if (rawSteps[k] !== 'doing' && next[k]) delete next[k];
|
| 228 |
+
});
|
| 229 |
+
if (JSON.stringify(next) !== JSON.stringify(enteredAt)) setEnteredAt(next);
|
| 230 |
+
}, [rawSteps]);
|
| 231 |
+
|
| 232 |
+
useEffect(() => {
|
| 233 |
+
const now = Date.now();
|
| 234 |
+
const k: StepKey = 'Counterfactual';
|
| 235 |
+
if (rawSteps[k] === 'doing' && enteredAt[k] && now - enteredAt[k] > STUCK_TIMEOUT * 1000) {
|
| 236 |
+
if (!forcedDone[k]) setForcedDone(prev => ({ ...prev, [k]: true }));
|
| 237 |
+
}
|
| 238 |
+
}, [enteredAt, rawSteps, forcedDone]);
|
| 239 |
+
|
| 240 |
+
const steps = useMemo(() => {
|
| 241 |
+
const s = { ...rawSteps } as Record<StepKey, StepState>;
|
| 242 |
+
(Object.keys(forcedDone) as StepKey[]).forEach((k) => {
|
| 243 |
+
if (forcedDone[k]) s[k] = 'done';
|
| 244 |
+
});
|
| 245 |
+
return s;
|
| 246 |
+
}, [rawSteps, forcedDone]);
|
| 247 |
+
|
| 248 |
+
const ft = resultsAny?.finetune || {};
|
| 249 |
+
const downloadPath: string | undefined =
|
| 250 |
+
ft.download_url || ft.model_url || ft.saved_model_path || resultsAny?.finetune_model_url;
|
| 251 |
+
|
| 252 |
+
const downloadHref = downloadPath ? MLBiasAPI.resolvePath(downloadPath) : undefined;
|
| 253 |
+
|
| 254 |
+
const baseSteps = [
|
| 255 |
+
{ key: 'Activate Task', icon: Rocket },
|
| 256 |
+
{ key: 'Load Dataset', icon: Database },
|
| 257 |
+
{ key: 'Load Model', icon: Brain },
|
| 258 |
+
{ key: 'Generate and Score', icon: Sparkles },
|
| 259 |
+
{ key: 'Counterfactual', icon: Sparkles },
|
| 260 |
+
{ key: 'Sampling', icon: LineChart },
|
| 261 |
+
{ key: 'Plot and Output', icon: LineChart },
|
| 262 |
+
] as const;
|
| 263 |
+
|
| 264 |
+
const stepList = wantFT
|
| 265 |
+
? [...baseSteps, { key: 'Finetune', icon: Rocket } as const]
|
| 266 |
+
: baseSteps;
|
| 267 |
+
|
| 268 |
+
const completedCount = stepList.reduce(
|
| 269 |
+
(acc, s) => acc + (steps[s.key as StepKey] === 'done' ? 1 : 0),
|
| 270 |
+
0
|
| 271 |
+
);
|
| 272 |
+
const doingCount = stepList.reduce(
|
| 273 |
+
(acc, s) => acc + (steps[s.key as StepKey] === 'doing' ? 1 : 0),
|
| 274 |
+
0
|
| 275 |
+
);
|
| 276 |
+
const percent = Math.min(
|
| 277 |
+
100,
|
| 278 |
+
Math.round(((completedCount + doingCount * 0.5) / stepList.length) * 100)
|
| 279 |
+
);
|
| 280 |
+
|
| 281 |
+
const hasStuckStep =
|
| 282 |
+
Object.values(steps).some((state) => state === 'doing') &&
|
| 283 |
+
elapsed > 60 &&
|
| 284 |
+
completedCount < stepList.length - 1;
|
| 285 |
+
|
| 286 |
+
return (
|
| 287 |
+
<div className="relative overflow-hidden rounded-2xl border border-indigo-200/50 bg-white/70 backdrop-blur">
|
| 288 |
+
<div className="absolute inset-0 -z-10 bg-[radial-gradient(800px_300px_at_20%_-20%,rgba(99,102,241,0.15),transparent_60%),radial-gradient(800px_300px_at_120%_0%,rgba(244,114,182,0.15),transparent_60%)]" />
|
| 289 |
+
<div className="p-6">
|
| 290 |
+
<div className="flex items-center justify-between mb-4">
|
| 291 |
+
<div>
|
| 292 |
+
<h3 className="text-lg font-semibold text-slate-900">Pipeline Running</h3>
|
| 293 |
+
<p className="text-sm text-slate-600">
|
| 294 |
+
Model: <span className="font-medium text-slate-800">{modelName || '(未指定)'}</span>
|
| 295 |
+
</p>
|
| 296 |
+
{hasStuckStep && (
|
| 297 |
+
<p className="text-xs text-amber-600 mt-1">⚠️ Some steps may run slowly and are automatically attempted to proceed.</p>
|
| 298 |
+
)}
|
| 299 |
+
</div>
|
| 300 |
+
<div className="flex items-center gap-2 text-slate-600">
|
| 301 |
+
<Loader2 className="w-5 h-5 animate-spin" />
|
| 302 |
+
<span className="text-sm">Executed {elapsed}s</span>
|
| 303 |
+
</div>
|
| 304 |
+
</div>
|
| 305 |
+
|
| 306 |
+
<div className="w-full h-3 rounded-full bg-slate-200 overflow-hidden">
|
| 307 |
+
<div
|
| 308 |
+
className="h-3 bg-gradient-to-r from-indigo-500 via-violet-500 to-fuchsia-500 transition-all duration-500"
|
| 309 |
+
style={{ width: `${percent}%` }}
|
| 310 |
+
/>
|
| 311 |
+
</div>
|
| 312 |
+
|
| 313 |
+
<ol className="mt-6 grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-3">
|
| 314 |
+
{stepList.map(({ key, icon: Icon }) => {
|
| 315 |
+
const state = steps[key as StepKey];
|
| 316 |
+
const isDone = state === 'done';
|
| 317 |
+
const isDoing = state === 'doing';
|
| 318 |
+
const startTs = enteredAt[key as StepKey];
|
| 319 |
+
const isStuck = isDoing && startTs && (Date.now() - startTs) / 1000 > STUCK_TIMEOUT;
|
| 320 |
+
|
| 321 |
+
return (
|
| 322 |
+
<li
|
| 323 |
+
key={key}
|
| 324 |
+
className={`flex items-center gap-3 rounded-xl border p-3 transition-all duration-300 ${
|
| 325 |
+
isDone
|
| 326 |
+
? 'border-emerald-200 bg-emerald-50'
|
| 327 |
+
: isDoing
|
| 328 |
+
? isStuck
|
| 329 |
+
? 'border-amber-300 bg-amber-100'
|
| 330 |
+
: 'border-amber-200 bg-amber-50'
|
| 331 |
+
: 'border-slate-200 bg-white/70'
|
| 332 |
+
}`}
|
| 333 |
+
>
|
| 334 |
+
<div
|
| 335 |
+
className={`rounded-lg p-2 transition-colors ${
|
| 336 |
+
isDone
|
| 337 |
+
? 'bg-emerald-100 text-emerald-700'
|
| 338 |
+
: isDoing
|
| 339 |
+
? isStuck
|
| 340 |
+
? 'bg-amber-200 text-amber-800'
|
| 341 |
+
: 'bg-amber-100 text-amber-700'
|
| 342 |
+
: 'bg-slate-100 text-slate-600'
|
| 343 |
+
}`}
|
| 344 |
+
>
|
| 345 |
+
{isDone ? (
|
| 346 |
+
<CheckCircle2 className="w-5 h-5" />
|
| 347 |
+
) : isDoing ? (
|
| 348 |
+
<Loader2 className="w-5 h-5 animate-spin" />
|
| 349 |
+
) : (
|
| 350 |
+
<Icon className="w-5 h-5" />
|
| 351 |
+
)}
|
| 352 |
+
</div>
|
| 353 |
+
<div className="flex-1">
|
| 354 |
+
<div className="text-sm font-medium text-slate-900">{key}</div>
|
| 355 |
+
<div className="text-xs text-slate-600">
|
| 356 |
+
{isDone ? 'Finished' : isDoing ? (isStuck ? 'Running...' : 'Running…') : 'Waiting'}
|
| 357 |
+
</div>
|
| 358 |
+
</div>
|
| 359 |
+
</li>
|
| 360 |
+
);
|
| 361 |
+
})}
|
| 362 |
+
</ol>
|
| 363 |
+
|
| 364 |
+
{/* 微調完成 → 顯示下載模型 */}
|
| 365 |
+
{wantFT && inferred.ftDone && downloadHref && (
|
| 366 |
+
<div className="mt-6">
|
| 367 |
+
<a
|
| 368 |
+
href={downloadHref}
|
| 369 |
+
className="inline-flex items-center rounded-xl border bg-white px-4 py-2 text-sm font-medium text-slate-800 hover:bg-slate-50"
|
| 370 |
+
target="_blank"
|
| 371 |
+
rel="noopener noreferrer"
|
| 372 |
+
>
|
| 373 |
+
<Rocket className="w-4 h-4 mr-2" />
|
| 374 |
+
Download Finetuned Model
|
| 375 |
+
</a>
|
| 376 |
+
{ft?.saved_model_path && (
|
| 377 |
+
<p className="mt-2 text-xs text-slate-500 break-all">
|
| 378 |
+
Model path: {String(ft.saved_model_path)}
|
| 379 |
+
</p>
|
| 380 |
+
)}
|
| 381 |
+
</div>
|
| 382 |
+
)}
|
| 383 |
+
</div>
|
| 384 |
+
</div>
|
| 385 |
+
);
|
| 386 |
+
}
|
frontend/src/components/validators/DatasetValidator.tsx
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// src/components/validators/DatasetValidator.tsx
|
| 2 |
+
import { useEffect } from 'react';
|
| 3 |
+
import { Loader, CheckCircle, XCircle, ExternalLink } from 'lucide-react';
|
| 4 |
+
import { useHFDatasetValidator } from '../../hooks/useHFValidators';
|
| 5 |
+
|
| 6 |
+
export default function DatasetValidator({ datasetId }: { datasetId: string }) {
|
| 7 |
+
const { loading, result, validate } = useHFDatasetValidator();
|
| 8 |
+
|
| 9 |
+
useEffect(() => {
|
| 10 |
+
validate(datasetId);
|
| 11 |
+
}, [datasetId, validate]);
|
| 12 |
+
|
| 13 |
+
if (!datasetId?.includes('/')) return null;
|
| 14 |
+
|
| 15 |
+
return (
|
| 16 |
+
<div className="mt-3">
|
| 17 |
+
{loading && (
|
| 18 |
+
<div className="flex items-center gap-2 p-3 bg-yellow-50 rounded-lg">
|
| 19 |
+
<Loader className="w-4 h-4 text-yellow-600 animate-spin" />
|
| 20 |
+
<span className="text-sm text-yellow-800">Validating Dataset...</span>
|
| 21 |
+
</div>
|
| 22 |
+
)}
|
| 23 |
+
|
| 24 |
+
{!!result && !loading && (
|
| 25 |
+
<div className={`p-3 rounded-lg ${result.isValid ? 'bg-green-50' : 'bg-red-50'}`}>
|
| 26 |
+
<div className="flex items-start gap-2">
|
| 27 |
+
{result.isValid ? (
|
| 28 |
+
<CheckCircle className="w-4 h-4 text-green-600 mt-0.5" />
|
| 29 |
+
) : (
|
| 30 |
+
<XCircle className="w-4 h-4 text-red-600 mt-0.5" />
|
| 31 |
+
)}
|
| 32 |
+
<div className="flex-1">
|
| 33 |
+
{result.isValid ? (
|
| 34 |
+
<>
|
| 35 |
+
<p className="text-sm font-medium text-green-800">✅ Dataset Verification Successful</p>
|
| 36 |
+
<div className="mt-2 space-y-1 text-xs text-green-700">
|
| 37 |
+
<p>
|
| 38 |
+
<strong>Author: </strong>
|
| 39 |
+
{result.datasetInfo.author}
|
| 40 |
+
</p>
|
| 41 |
+
<p>
|
| 42 |
+
<strong>Download: </strong>
|
| 43 |
+
{result.datasetInfo.downloads.toLocaleString()}
|
| 44 |
+
</p>
|
| 45 |
+
{!!result.datasetInfo.task_categories?.length && (
|
| 46 |
+
<p>
|
| 47 |
+
<strong>Task: </strong>
|
| 48 |
+
{result.datasetInfo.task_categories.join(', ')}
|
| 49 |
+
</p>
|
| 50 |
+
)}
|
| 51 |
+
<p>
|
| 52 |
+
<strong>Description: </strong>
|
| 53 |
+
{result.datasetInfo.description.slice(0, 100)}...
|
| 54 |
+
</p>
|
| 55 |
+
</div>
|
| 56 |
+
<a
|
| 57 |
+
href={`https://huggingface.co/datasets/${datasetId}`}
|
| 58 |
+
target="_blank"
|
| 59 |
+
rel="noreferrer"
|
| 60 |
+
className="inline-flex items-center gap-1 text-xs text-blue-600 hover:text-blue-800 mt-2"
|
| 61 |
+
>
|
| 62 |
+
<ExternalLink className="w-3 h-3" />
|
| 63 |
+
<span>View on Hugging Face</span>
|
| 64 |
+
</a>
|
| 65 |
+
</>
|
| 66 |
+
) : (
|
| 67 |
+
<>
|
| 68 |
+
<p className="text-sm font-medium text-red-800">❌ Dataset Verification Failed</p>
|
| 69 |
+
<p className="text-xs text-red-700 mt-1">{result.error}</p>
|
| 70 |
+
</>
|
| 71 |
+
)}
|
| 72 |
+
</div>
|
| 73 |
+
</div>
|
| 74 |
+
</div>
|
| 75 |
+
)}
|
| 76 |
+
</div>
|
| 77 |
+
);
|
| 78 |
+
}
|
frontend/src/components/validators/ModelValidator.tsx
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// src/components/validators/ModelValidator.tsx
|
| 2 |
+
import { useEffect } from 'react';
|
| 3 |
+
import { Loader, CheckCircle, XCircle, ExternalLink } from 'lucide-react';
|
| 4 |
+
import { useHFModelValidator } from '../../hooks/useHFValidators';
|
| 5 |
+
|
| 6 |
+
export default function ModelValidator({
|
| 7 |
+
modelId,
|
| 8 |
+
type,
|
| 9 |
+
}: {
|
| 10 |
+
modelId: string;
|
| 11 |
+
type: 'language' | 'scorer';
|
| 12 |
+
}) {
|
| 13 |
+
const { loading, result, validate } = useHFModelValidator();
|
| 14 |
+
|
| 15 |
+
useEffect(() => {
|
| 16 |
+
validate(modelId, type);
|
| 17 |
+
}, [modelId, type, validate]);
|
| 18 |
+
|
| 19 |
+
if (!modelId?.includes('/')) return null;
|
| 20 |
+
|
| 21 |
+
return (
|
| 22 |
+
<div className="mt-3">
|
| 23 |
+
{loading && (
|
| 24 |
+
<div className="flex items-center gap-2 p-3 bg-yellow-50 rounded-lg">
|
| 25 |
+
<Loader className="w-4 h-4 text-yellow-600 animate-spin" />
|
| 26 |
+
<span className="text-sm text-yellow-800">Validating Model...</span>
|
| 27 |
+
</div>
|
| 28 |
+
)}
|
| 29 |
+
|
| 30 |
+
{!!result && !loading && (
|
| 31 |
+
<div className={`p-3 rounded-lg ${result.isValid ? 'bg-green-50' : 'bg-red-50'}`}>
|
| 32 |
+
<div className="flex items-start gap-2">
|
| 33 |
+
{result.isValid ? (
|
| 34 |
+
<CheckCircle className="w-4 h-4 text-green-600 mt-0.5" />
|
| 35 |
+
) : (
|
| 36 |
+
<XCircle className="w-4 h-4 text-red-600 mt-0.5" />
|
| 37 |
+
)}
|
| 38 |
+
<div className="flex-1">
|
| 39 |
+
{result.isValid ? (
|
| 40 |
+
<>
|
| 41 |
+
<p className="text-sm font-medium text-green-800">✅ Model Verification Sucessful</p>
|
| 42 |
+
<div className="mt-2 space-y-1 text-xs text-green-700">
|
| 43 |
+
<p>
|
| 44 |
+
<strong>Author: </strong>
|
| 45 |
+
{result.modelInfo.author}
|
| 46 |
+
</p>
|
| 47 |
+
<p>
|
| 48 |
+
<strong>Download: </strong>
|
| 49 |
+
{result.modelInfo.downloads.toLocaleString()}
|
| 50 |
+
</p>
|
| 51 |
+
{!!result.modelInfo.pipeline_tag && (
|
| 52 |
+
<p>
|
| 53 |
+
<strong>Task: </strong>
|
| 54 |
+
{result.modelInfo.pipeline_tag}
|
| 55 |
+
</p>
|
| 56 |
+
)}
|
| 57 |
+
</div>
|
| 58 |
+
<a
|
| 59 |
+
href={`https://huggingface.co/${modelId}`}
|
| 60 |
+
target="_blank"
|
| 61 |
+
rel="noreferrer"
|
| 62 |
+
className="inline-flex items-center gap-1 text-xs text-blue-600 hover:text-blue-800 mt-2"
|
| 63 |
+
>
|
| 64 |
+
<ExternalLink className="w-3 h-3" />
|
| 65 |
+
<span>View on Hugging Face</span>
|
| 66 |
+
</a>
|
| 67 |
+
</>
|
| 68 |
+
) : (
|
| 69 |
+
<>
|
| 70 |
+
<p className="text-sm font-medium text-red-800">❌ Model Verification Failed</p>
|
| 71 |
+
<p className="text-xs text-red-700 mt-1">{result.error}</p>
|
| 72 |
+
</>
|
| 73 |
+
)}
|
| 74 |
+
</div>
|
| 75 |
+
</div>
|
| 76 |
+
</div>
|
| 77 |
+
)}
|
| 78 |
+
</div>
|
| 79 |
+
);
|
| 80 |
+
}
|
frontend/src/constants/datasets.ts
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { Dataset } from '../types';
|
| 2 |
+
|
| 3 |
+
export const DATASETS: Dataset[] = [
|
| 4 |
+
{
|
| 5 |
+
id: 'AmazonScience/bold',
|
| 6 |
+
name: 'BOLD'
|
| 7 |
+
},
|
| 8 |
+
];
|
| 9 |
+
|
| 10 |
+
export default DATASETS;
|
frontend/src/constants/models.ts
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { Dataset, Model } from '../types';
|
| 2 |
+
|
| 3 |
+
export const DATASETS: Dataset[] = [
|
| 4 |
+
{ id: 'AmazonScience/bold', name: 'BOLD' }
|
| 5 |
+
];
|
| 6 |
+
|
| 7 |
+
export const LM_MODELS: Model[] = [
|
| 8 |
+
{ id: 'microsoft/DialoGPT-large', name: 'DialoGPT-large', type: 'language', description: 'Microsoft 對話生成模型', provider: 'Microsoft' },
|
| 9 |
+
{ id: 'openai-community/gpt2', name: 'GPT-2', type: 'language', description: 'OpenAI GPT-2 基礎模型', provider: 'OpenAI' },
|
| 10 |
+
{ id: 'EleutherAI/gpt-neo-2.7B', name: 'GPT-Neo-2.7B', type: 'language', description: 'EleutherAI 開源語言模型', provider: 'EleutherAI' }
|
| 11 |
+
];
|
frontend/src/hooks/JobRunnerProvider.tsx
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React, { createContext, useContext, useMemo, useState } from 'react';
|
| 2 |
+
import type { JobConfig, JobResult } from '../types';
|
| 3 |
+
import { MLBiasAPI } from '../services/api';
|
| 4 |
+
|
| 5 |
+
type PipelinePlots = {
|
| 6 |
+
original_sentiment: string;
|
| 7 |
+
counterfactual_sentiment: string;
|
| 8 |
+
};
|
| 9 |
+
|
| 10 |
+
type PipelineResultsDTO = {
|
| 11 |
+
generation_file: string;
|
| 12 |
+
sentiment_subset_file: string;
|
| 13 |
+
cf_sentiment_subset_file: string;
|
| 14 |
+
metrics: {
|
| 15 |
+
finalMeanDiff: number;
|
| 16 |
+
cfFinalMeanDiff: number;
|
| 17 |
+
reductionPct?: number;
|
| 18 |
+
stableCoverage?: number;
|
| 19 |
+
};
|
| 20 |
+
plots: PipelinePlots;
|
| 21 |
+
finetuned_model_zip?: string;
|
| 22 |
+
finetuned_model_dir?: string;
|
| 23 |
+
run_config_files?: {
|
| 24 |
+
json?: string;
|
| 25 |
+
markdown?: string;
|
| 26 |
+
};
|
| 27 |
+
};
|
| 28 |
+
|
| 29 |
+
type PipelineResponseDTO = {
|
| 30 |
+
status: 'success' | 'error';
|
| 31 |
+
message: string;
|
| 32 |
+
timestamp: string;
|
| 33 |
+
results: PipelineResultsDTO;
|
| 34 |
+
};
|
| 35 |
+
|
| 36 |
+
type Extras = {
|
| 37 |
+
datasetLimit: number
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
type Ctx = {
|
| 41 |
+
result: JobResult | null;
|
| 42 |
+
resp?: PipelineResponseDTO;
|
| 43 |
+
loading: boolean;
|
| 44 |
+
error?: string;
|
| 45 |
+
start: (cfg: JobConfig, extras: Extras) => Promise<void>;
|
| 46 |
+
url: (p?: string) => string;
|
| 47 |
+
};
|
| 48 |
+
|
| 49 |
+
const JobRunnerContext = createContext<Ctx | undefined>(undefined);
|
| 50 |
+
|
| 51 |
+
export function JobRunnerProvider({ children }: { children: React.ReactNode }) {
|
| 52 |
+
const [result, setResult] = useState<JobResult | null>(null);
|
| 53 |
+
const [resp, setResp] = useState<PipelineResponseDTO | undefined>();
|
| 54 |
+
const [loading, setLoading] = useState(false);
|
| 55 |
+
const [error, setErr] = useState<string | undefined>();
|
| 56 |
+
|
| 57 |
+
const start: Ctx['start'] = async (cfg, extras) => {
|
| 58 |
+
setLoading(true);
|
| 59 |
+
setErr(undefined);
|
| 60 |
+
setResp(undefined);
|
| 61 |
+
|
| 62 |
+
const now = new Date().toISOString();
|
| 63 |
+
setResult({
|
| 64 |
+
id: crypto.randomUUID(),
|
| 65 |
+
status: 'running',
|
| 66 |
+
progress: 0,
|
| 67 |
+
config: cfg,
|
| 68 |
+
createdAt: now,
|
| 69 |
+
updatedAt: now,
|
| 70 |
+
});
|
| 71 |
+
|
| 72 |
+
try {
|
| 73 |
+
const cfgToSend = {
|
| 74 |
+
...cfg,
|
| 75 |
+
datasetLimit: extras.datasetLimit
|
| 76 |
+
} as unknown as JobConfig;
|
| 77 |
+
|
| 78 |
+
const r = await MLBiasAPI.runPipeline(cfgToSend as any);
|
| 79 |
+
|
| 80 |
+
setResp(r);
|
| 81 |
+
|
| 82 |
+
const done = new Date().toISOString();
|
| 83 |
+
setResult((prev) => ({
|
| 84 |
+
...(prev as JobResult),
|
| 85 |
+
status: 'completed',
|
| 86 |
+
progress: 100,
|
| 87 |
+
updatedAt: done,
|
| 88 |
+
completedAt: done,
|
| 89 |
+
metrics: {
|
| 90 |
+
finalMeanDiff: r.results.metrics.finalMeanDiff,
|
| 91 |
+
reductionPct: r.results.metrics.reductionPct ?? 0,
|
| 92 |
+
stableCoverage: r.results.metrics.stableCoverage ?? 100,
|
| 93 |
+
},
|
| 94 |
+
}));
|
| 95 |
+
} catch (e: any) {
|
| 96 |
+
setErr(e.message || String(e));
|
| 97 |
+
setResult((prev) =>
|
| 98 |
+
prev
|
| 99 |
+
? { ...prev, status: 'failed', progress: 100, updatedAt: new Date().toISOString() }
|
| 100 |
+
: prev
|
| 101 |
+
);
|
| 102 |
+
} finally {
|
| 103 |
+
setLoading(false);
|
| 104 |
+
}
|
| 105 |
+
};
|
| 106 |
+
|
| 107 |
+
const url = MLBiasAPI.resolvePath;
|
| 108 |
+
|
| 109 |
+
const value = useMemo<Ctx>(
|
| 110 |
+
() => ({ result, resp, loading, error, start, url }),
|
| 111 |
+
[result, resp, loading, error]
|
| 112 |
+
);
|
| 113 |
+
|
| 114 |
+
return <JobRunnerContext.Provider value={value}>{children}</JobRunnerContext.Provider>;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
export function useJobRunner() {
|
| 118 |
+
const ctx = useContext(JobRunnerContext);
|
| 119 |
+
if (!ctx) throw new Error('useJobRunner must be used within JobRunnerProvider');
|
| 120 |
+
return ctx;
|
| 121 |
+
}
|
frontend/src/hooks/useHFValidators.ts
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// src/hooks/useHFValidators.ts
|
| 2 |
+
import { useMemo, useRef, useState } from 'react';
|
| 3 |
+
import { fetchHFModel, fetchHFDataset } from '../services/hf';
|
| 4 |
+
|
| 5 |
+
const debounce = (fn: (...args: any[]) => void, ms = 350) => {
|
| 6 |
+
let t: any;
|
| 7 |
+
return (...args: any[]) => {
|
| 8 |
+
clearTimeout(t);
|
| 9 |
+
t = setTimeout(() => fn(...args), ms);
|
| 10 |
+
};
|
| 11 |
+
};
|
| 12 |
+
|
| 13 |
+
export function useHFModelValidator() {
|
| 14 |
+
const cache = useRef<Map<string, any>>(new Map());
|
| 15 |
+
const [loading, setLoading] = useState(false);
|
| 16 |
+
const [result, setResult] = useState<any>(null);
|
| 17 |
+
|
| 18 |
+
const validate = useMemo(
|
| 19 |
+
() =>
|
| 20 |
+
debounce(async (modelId: string, expected: 'language' | 'scorer') => {
|
| 21 |
+
if (!modelId?.includes('/')) {
|
| 22 |
+
setResult(null);
|
| 23 |
+
return;
|
| 24 |
+
}
|
| 25 |
+
if (cache.current.has(modelId)) {
|
| 26 |
+
setResult(cache.current.get(modelId));
|
| 27 |
+
return;
|
| 28 |
+
}
|
| 29 |
+
setLoading(true);
|
| 30 |
+
try {
|
| 31 |
+
const info = await fetchHFModel(modelId);
|
| 32 |
+
const id: string = info.id?.toLowerCase() ?? '';
|
| 33 |
+
const tags: string[] = info.tags ?? [];
|
| 34 |
+
let actual: string | undefined = info.pipeline_tag;
|
| 35 |
+
|
| 36 |
+
if (!actual) {
|
| 37 |
+
if (id.includes('t5') || tags.includes('text2text-generation')) actual = 'text2text-generation';
|
| 38 |
+
else if (tags.includes('text-generation')) actual = 'text-generation';
|
| 39 |
+
else if (tags.includes('text-classification')) actual = 'text-classification';
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
const ok =
|
| 43 |
+
expected === 'language'
|
| 44 |
+
? ['text-generation', 'text2text-generation'].includes(actual || '')
|
| 45 |
+
: actual === 'text-classification';
|
| 46 |
+
|
| 47 |
+
const payload = ok
|
| 48 |
+
? {
|
| 49 |
+
isValid: true,
|
| 50 |
+
modelInfo: {
|
| 51 |
+
id: info.id,
|
| 52 |
+
downloads: info.downloads ?? 0,
|
| 53 |
+
pipeline_tag: actual,
|
| 54 |
+
tags,
|
| 55 |
+
author: info.id?.split('/')?.[0] ?? 'unknown',
|
| 56 |
+
modelName: info.id?.split('/')?.[1] ?? info.id,
|
| 57 |
+
},
|
| 58 |
+
}
|
| 59 |
+
: {
|
| 60 |
+
isValid: false,
|
| 61 |
+
error:
|
| 62 |
+
expected === 'language'
|
| 63 |
+
? `Model task should be text-generation or text2text-generation, but is ${actual || 'Unknown'}」`
|
| 64 |
+
: `Model task should be text-classification, but is${actual || 'Unknown'}」`,
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
cache.current.set(modelId, payload);
|
| 68 |
+
setResult(payload);
|
| 69 |
+
} catch (e: any) {
|
| 70 |
+
setResult({ isValid: false, error: e?.message || 'Error when valiating model' });
|
| 71 |
+
} finally {
|
| 72 |
+
setLoading(false);
|
| 73 |
+
}
|
| 74 |
+
}),
|
| 75 |
+
[]
|
| 76 |
+
);
|
| 77 |
+
|
| 78 |
+
return { loading, result, validate };
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
export function useHFDatasetValidator() {
|
| 82 |
+
const cache = useRef<Map<string, any>>(new Map());
|
| 83 |
+
const [loading, setLoading] = useState(false);
|
| 84 |
+
const [result, setResult] = useState<any>(null);
|
| 85 |
+
|
| 86 |
+
const validate = useMemo(
|
| 87 |
+
() =>
|
| 88 |
+
debounce(async (datasetId: string) => {
|
| 89 |
+
if (!datasetId?.includes('/')) {
|
| 90 |
+
setResult(null);
|
| 91 |
+
return;
|
| 92 |
+
}
|
| 93 |
+
if (cache.current.has(datasetId)) {
|
| 94 |
+
setResult(cache.current.get(datasetId));
|
| 95 |
+
return;
|
| 96 |
+
}
|
| 97 |
+
const valid = /^[a-zA-Z0-9._-]+\/[a-zA-Z0-9._-]+$/.test(datasetId);
|
| 98 |
+
if (!valid) {
|
| 99 |
+
setResult({ isValid: false, error: 'Incorrect Dataset ID Format' });
|
| 100 |
+
return;
|
| 101 |
+
}
|
| 102 |
+
setLoading(true);
|
| 103 |
+
try {
|
| 104 |
+
const info = await fetchHFDataset(datasetId);
|
| 105 |
+
const payload = {
|
| 106 |
+
isValid: true,
|
| 107 |
+
datasetInfo: {
|
| 108 |
+
id: info.id,
|
| 109 |
+
author: info.id?.split('/')?.[0] ?? 'unknown',
|
| 110 |
+
datasetName: info.id?.split('/')?.[1] ?? info.id,
|
| 111 |
+
downloads: info.downloads ?? 0,
|
| 112 |
+
tags: info.tags ?? [],
|
| 113 |
+
description: info.description ?? 'No Description',
|
| 114 |
+
task_categories: info.task_categories ?? [],
|
| 115 |
+
},
|
| 116 |
+
};
|
| 117 |
+
cache.current.set(datasetId, payload);
|
| 118 |
+
setResult(payload);
|
| 119 |
+
} catch (e: any) {
|
| 120 |
+
setResult({ isValid: false, error: e?.message || 'An error occurred while validating the dataset' });
|
| 121 |
+
} finally {
|
| 122 |
+
setLoading(false);
|
| 123 |
+
}
|
| 124 |
+
}),
|
| 125 |
+
[]
|
| 126 |
+
);
|
| 127 |
+
|
| 128 |
+
return { loading, result, validate };
|
| 129 |
+
}
|
frontend/src/hooks/useIterationData.ts
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useMemo } from 'react';
|
| 2 |
+
|
| 3 |
+
export function useIterationData(seed = 42, points = 50) {
|
| 4 |
+
function mulberry32(a:number){return function(){let t=(a+=0x6D2B79F5);t=Math.imul(t^(t>>>15),t|1);t^=t+Math.imul(t^(t>>>7),t|61);return ((t^(t>>>14))>>>0)/4294967296;};}
|
| 5 |
+
return useMemo(() => {
|
| 6 |
+
const rand = mulberry32(seed);
|
| 7 |
+
return Array.from({ length: points }, (_, i) => ({
|
| 8 |
+
iteration: i+1,
|
| 9 |
+
meanDifference: Math.max(0.1, 0.8 - i*0.012 + rand()*0.1),
|
| 10 |
+
groupA: 0.7 - i*0.006 + rand()*0.05,
|
| 11 |
+
groupB: 0.3 + i*0.003 + rand()*0.05,
|
| 12 |
+
}));
|
| 13 |
+
}, [seed, points]);
|
| 14 |
+
}
|
frontend/src/hooks/useJobRunner.ts
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useEffect, useMemo, useRef, useState } from 'react';
|
| 2 |
+
import type { JobConfig, JobResult } from '../types';
|
| 3 |
+
import type { PipelineResponseDTO } from '../services/api';
|
| 4 |
+
import { MLBiasAPI } from '../services/api';
|
| 5 |
+
|
| 6 |
+
// Loose shape of the backend health/status payload polled during a run.
// Every field is optional because different pipeline phases populate
// different subsets of this object.
type HealthLike = {
  job_id?: string;
  // Two timestamp variants appear in payloads; consumers prefer `updated_at`
  // and fall back to `timestamp`.
  timestamp?: string;
  updated_at?: string;
  dataset_loaded?: boolean;
  loaded_models?: string[];
  generation_results_available?: boolean;
  finetune_running?: boolean;
  // Per-step progress map. Older payloads use booleans; newer ones use the
  // tri-state strings 'todo' | 'doing' | 'done'.
  steps?: Record<string, boolean | 'todo' | 'doing' | 'done'>;
  completed?: boolean;
  status?: string;
};

// Public contract returned by useJobRunner().
type UseJobRunnerReturn = {
  // Client-side job record (provisional until the backend confirms an id).
  result: JobResult | null;
  // Final pipeline response, set when the backend returns metrics directly.
  resp: PipelineResponseDTO | undefined;
  loading: boolean;
  error?: string;

  // Kick off a pipeline run; resolves once the run is started (or finished).
  start: (config: JobConfig) => Promise<void>;
  // Stop client-side polling only; does NOT abort the backend job.
  cancel: () => void;

  jobId: string | null;
  live: {
    // Most recent raw health payload, or null before the first poll.
    health: HealthLike | null;
    // Normalized step map (true = done or currently in progress).
    steps: Record<string, boolean>;
    updatedAt: string | null;
    finetuneRunning: boolean;
    // 0-100, derived from step completion (see progressPercent in the hook).
    progressPercent: number;
  };

  // Helper for resolving backend-relative asset/plot paths.
  url: typeof MLBiasAPI.resolvePath;
};
|
| 39 |
+
|
| 40 |
+
// React hook that starts an ML-bias pipeline run and tracks its progress.
// Two completion paths exist:
//   1. The run endpoint returns final metrics synchronously -> finish at once.
//   2. Otherwise, poll the health endpoint every second until a completion
//      condition is observed.
export function useJobRunner(): UseJobRunnerReturn {
  const [jobId, setJobId] = useState<string | null>(null);

  const [result, setResult] = useState<JobResult | null>(null);
  const [resp, setResp] = useState<PipelineResponseDTO | undefined>();

  const [loading, setLoading] = useState(false);
  const [error, setErr] = useState<string | undefined>();

  const [health, setHealth] = useState<HealthLike | null>(null);

  // Interval handle for the 1s health poll; null when not polling.
  const pollRef = useRef<number | null>(null);
  // Guards against installing the interval after cancel() during the awaited
  // first pollOnce() in start().
  const aliveRef = useRef<boolean>(false);

  const stopPolling = () => {
    if (pollRef.current) {
      window.clearInterval(pollRef.current);
      pollRef.current = null;
    }
    aliveRef.current = false;
  };

  // Client-side cancel only: stops polling and clears the loading flag.
  // The backend job is NOT aborted.
  const cancel = () => {
    stopPolling();
    setLoading(false);
  };

  // Percentage derived from the health step map: done = 1, doing = 0.5.
  // Falls back to result.progress when no steps have been reported yet.
  const progressPercent = useMemo(() => {
    const s = (health?.steps as Record<string, boolean | string>) || {};
    const keys = Object.keys(s);
    if (keys.length === 0) return result?.progress ?? 0;
    let score = 0;
    keys.forEach((k) => {
      const v = s[k];
      if (v === true || v === 'done') score += 1;
      else if (v === 'doing') score += 0.5;
    });
    return Math.max(0, Math.min(100, Math.round((score / keys.length) * 100)));
  }, [health?.steps, result?.progress]);

  // Merge step flags from the final response and the live health payload into
  // a plain boolean map. Health wins on key collisions; 'doing' counts as true
  // so in-progress steps light up in the UI.
  const liveSteps: Record<string, boolean> = useMemo(() => {
    const fromResp = ((resp?.results as any)?.steps || {}) as Record<string, boolean>;
    const fromHealth = ((health?.steps || {}) as Record<string, boolean | string>);
    const normalized: Record<string, boolean> = {};
    Object.keys(fromResp).forEach((k) => (normalized[k] = !!(fromResp as any)[k]));
    Object.keys(fromHealth).forEach((k) => {
      const v = (fromHealth as any)[k];
      normalized[k] = v === true || v === 'done' || v === 'doing';
    });
    return normalized;
  }, [health?.steps, resp?.results]);

  // One polling tick: fetch health, publish it, and stop when any completion
  // heuristic fires.
  // NOTE(review): setInterval captures the pollOnce closure created in the
  // render where start() ran, so `resp` read below is frozen at that moment
  // (undefined, since start() calls setResp(undefined) and resp is never set
  // again on the polling path). The resp-based fallbacks here therefore only
  // matter if resp was populated before polling began — confirm intent.
  const pollOnce = async () => {
    try {
      const h = (await MLBiasAPI.checkHealth()) as HealthLike;
      // Keep the previous object reference when the payload is unchanged to
      // avoid needless re-renders (deep-compare via JSON).
      setHealth((prev) => (JSON.stringify(prev) === JSON.stringify(h) ? prev : h));

      const steps = (h?.steps || {}) as Record<string, boolean | string>;
      const plotsDone =
        !!steps['6_plots_and_metrics'] ||
        (resp?.results as any)?.plots_ready ||
        ((resp?.results as any)?.plot_urls?.length ?? 0) > 0;

      const r4 = !!steps['4_rank_sampling_original'];
      const r5 = !!steps['5_rank_sampling_cf'];
      const samplingDone = r4 && r5;

      const genAvailable = !!h?.generation_results_available;
      const ftMaybeDone =
        !!steps['7_finetune'] ||
        (resp?.results as any)?.finetune_done ||
        (resp?.results as any)?.finetune?.completed;

      const declaredCompleted = h?.completed === true || h?.status === 'completed';

      // Completion = explicit backend flag, plots ready, both sampling steps
      // done, or generations available with fine-tune finished.
      if (declaredCompleted || plotsDone || samplingDone || (genAvailable && ftMaybeDone)) {
        stopPolling();
        setLoading(false);
      }
    } catch (e: any) {
      // NOTE(review): a single transient poll failure records a persistent
      // `error` while polling continues — verify the UI treats this as
      // non-fatal.
      setErr((e && e.message) || String(e));
    }
  };

  // Start a pipeline run. Seeds a provisional client-side result record,
  // then either finishes immediately (synchronous metrics) or begins polling.
  const start = async (config: JobConfig) => {
    setLoading(true);
    setErr(undefined);

    const now = new Date().toISOString();
    // Provisional id until the backend supplies one.
    const provisionalId = crypto.randomUUID();
    setResult({
      id: provisionalId,
      status: 'running',
      progress: 0,
      config,
      createdAt: now,
      updatedAt: now,
    });
    setResp(undefined);
    setHealth(null);

    try {
      const runResp: any = await MLBiasAPI.runPipeline(config);

      // Backend id may arrive under several keys depending on the endpoint.
      const jid: string | undefined =
        runResp?.jobId || runResp?.job_id || runResp?.results?.jobId || runResp?.results?.job_id;
      setJobId(jid || provisionalId);

      // Fast path: the run endpoint returned final metrics directly.
      if (runResp?.results?.metrics) {
        const final = runResp as PipelineResponseDTO;
        const now2 = new Date().toISOString();

        setResp(final);
        setResult({
          id: jid || provisionalId,
          status: 'completed',
          progress: 100,
          config,
          createdAt: now,
          updatedAt: now2,
          completedAt: now2,
          metrics: {
            finalMeanDiff: final.results.metrics.finalMeanDiff,
            reductionPct: final.results.metrics.reductionPct ?? 0,
            stableCoverage: final.results.metrics.stableCoverage ?? 100,
          },
        });
        setLoading(false);
        return;
      }

      // Slow path: poll once immediately, then every second unless a
      // cancel() landed during the awaited first poll.
      aliveRef.current = true;
      await pollOnce();
      if (aliveRef.current) {
        pollRef.current = window.setInterval(pollOnce, 1000);
      }
    } catch (e: any) {
      setErr(e?.message || String(e));
      setResult((prev) =>
        prev
          ? { ...prev, status: 'failed', progress: 100, updatedAt: new Date().toISOString() }
          : null
      );
      setLoading(false);
    }
  };

  // Clear any live interval on unmount.
  useEffect(() => stopPolling, []);

  const url = MLBiasAPI.resolvePath;

  return {
    result,
    resp,
    loading,
    error,
    start,
    cancel,
    jobId,
    live: {
      health,
      steps: liveSteps,
      // Prefer updated_at; older payloads only carry timestamp.
      updatedAt: (health && (health.updated_at || health.timestamp)) || null,
      finetuneRunning: !!(health?.finetune_running || (resp as any)?.results?.finetune?.running),
      progressPercent,
    },
    url,
  };
}
|
frontend/src/index.css
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@tailwind base;
|
| 2 |
+
@tailwind components;
|
| 3 |
+
@tailwind utilities;
|
| 4 |
+
|
frontend/src/main.tsx
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import React from 'react'
import ReactDOM from 'react-dom/client'
import App from './App.tsx'
import './index.css'

// Application entry point: mount the React tree onto the #root element.
// StrictMode enables development-only checks (double-invoked renders/effects).
// The non-null assertion is safe as long as index.html provides <div id="root">.
ReactDOM.createRoot(document.getElementById('root')!).render(
  <React.StrictMode>
    <App />
  </React.StrictMode>,
)
|
frontend/src/pages/ConfigPage.tsx
ADDED
|
@@ -0,0 +1,649 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useEffect, useState } from 'react';
|
| 2 |
+
import { Database, Bot, ExternalLink, Shuffle } from 'lucide-react';
|
| 3 |
+
import DatasetValidator from '../components/validators/DatasetValidator';
|
| 4 |
+
import ModelValidator from '../components/validators/ModelValidator';
|
| 5 |
+
import { DATASETS } from '../constants/datasets';
|
| 6 |
+
import { LM_MODELS} from '../constants/models';
|
| 7 |
+
import type { JobConfig } from '../types';
|
| 8 |
+
|
| 9 |
+
type Extras = {
|
| 10 |
+
datasetLimit: number,
|
| 11 |
+
};
|
| 12 |
+
|
| 13 |
+
export default function ConfigPage({ onRun }: { onRun: (cfg: JobConfig, extras: Extras) => void }) {
|
| 14 |
+
const [cfg, setCfg] = useState<JobConfig>({
|
| 15 |
+
dataset: '',
|
| 16 |
+
languageModel: '',
|
| 17 |
+
scorerModel: '',
|
| 18 |
+
k: 5,
|
| 19 |
+
numCounterfactuals: 3,
|
| 20 |
+
metrictarget: 0.5,
|
| 21 |
+
tau: 0.1,
|
| 22 |
+
iterations: 1000,
|
| 23 |
+
seed: 42,
|
| 24 |
+
enableFineTuning: false,
|
| 25 |
+
counterfactual: false,
|
| 26 |
+
});
|
| 27 |
+
|
| 28 |
+
const [datasetLimit, setDatasetLimit] = useState<number>(10);
|
| 29 |
+
const [customDataset, setCustomDataset] = useState('');
|
| 30 |
+
const [customLM, setCustomLM] = useState('');
|
| 31 |
+
const [showCustomDatasetInput, setShowCustomDatasetInput] = useState(false);
|
| 32 |
+
const [showCustomLanguageInput, setShowCustomLanguageInput] = useState(false);
|
| 33 |
+
const [fieldStats, setFieldStats] = useState<Record<string, Record<string, number>>>({});
|
| 34 |
+
const [numCounterfactuals, setNumCounterfactuals] = useState<number>(3);
|
| 35 |
+
const [classificationTask, setClassificationTask] = useState<'sentiment' | 'regard' | 'stereotype' | 'personality' | 'toxicity'>('sentiment');
|
| 36 |
+
const [toxicityModelChoice, setToxicityModelChoice] = useState<'detoxify' | 'junglelee'>('detoxify');
|
| 37 |
+
const [selectedCfFields, setSelectedCfFields] = useState<string[]>([]);
|
| 38 |
+
const [availableFields, setAvailableFields] = useState<string[]>([]);
|
| 39 |
+
const [isLoadingFields, setIsLoadingFields] = useState(false);
|
| 40 |
+
const [fieldsError, setFieldsError] = useState<string | null>(null);
|
| 41 |
+
const [metaConfigs, setMetaConfigs] = useState<string[]>([]);
|
| 42 |
+
const [metaSplits, setMetaSplits] = useState<string[]>([]);
|
| 43 |
+
const [selectedConfig, setSelectedConfig] = useState<string | null>(null);
|
| 44 |
+
const [selectedSplit, setSelectedSplit] = useState<string>('train');
|
| 45 |
+
|
| 46 |
+
const canStart = !!(cfg.dataset && cfg.languageModel);
|
| 47 |
+
const [ftEpochs, setFtEpochs] = useState(3);
|
| 48 |
+
const [ftBatchSize, setFtBatchSize] = useState(8);
|
| 49 |
+
const [ftLR, setFtLR] = useState(5e-5);
|
| 50 |
+
const setField = <K extends keyof JobConfig>(k: K, v: JobConfig[K]) =>
|
| 51 |
+
setCfg((prev) => ({ ...prev, [k]: v }));
|
| 52 |
+
|
| 53 |
+
const card = 'group relative rounded-2xl p-8 border border-white/30 bg-white/60 backdrop-blur-xl ' +
|
| 54 |
+
'shadow-[0_15px_40px_-20px_rgba(30,41,59,0.35)] transition-all duration-300 ' +
|
| 55 |
+
'hover:shadow-[0_20px_50px_-20px_rgba(79,70,229,0.45)] hover:-translate-y-0.5';
|
| 56 |
+
|
| 57 |
+
const sectionTitle = 'text-xl font-bold tracking-tight text-slate-900';
|
| 58 |
+
const subtext = 'text-sm text-slate-600';
|
| 59 |
+
const fieldInput = 'w-full rounded-xl border-2 border-slate-200/70 bg-white/70 px-4 py-3 ' +
|
| 60 |
+
'focus:outline-none focus:border-indigo-500 focus:ring-4 focus:ring-indigo-500/20 transition-all';
|
| 61 |
+
const selectInput = 'w-full rounded-xl border-2 border-slate-200/70 bg-white/70 px-3 py-2.5 ' +
|
| 62 |
+
'focus:outline-none focus:border-indigo-500 focus:ring-4 focus:ring-indigo-500/20 transition-all';
|
| 63 |
+
const choiceRow = 'flex items-start gap-4 cursor-pointer p-4 rounded-xl border transition-colors ' +
|
| 64 |
+
'bg-white/60 hover:bg-white/80 border-slate-200/60 hover:border-indigo-300';
|
| 65 |
+
|
| 66 |
+
const currentDataset = DATASETS.find((d) => d.id === cfg.dataset);
|
| 67 |
+
const fallbackFields: string[] = (currentDataset as any)?.fields || ['text', 'label', 'group'];
|
| 68 |
+
|
| 69 |
+
const toggleCfField = (f: string) =>
|
| 70 |
+
setSelectedCfFields((prev) =>
|
| 71 |
+
(prev.includes(f) ? prev.filter((x) => x !== f) : [...prev, f])
|
| 72 |
+
);
|
| 73 |
+
|
| 74 |
+
const API_BASE = '/api';
|
| 75 |
+
|
| 76 |
+
async function fetchJSON<T>(url: string, signal?: AbortSignal): Promise<T> {
|
| 77 |
+
const fullURL = url.startsWith('http') ? url : `${API_BASE}${url}`;
|
| 78 |
+
const res = await fetch(fullURL, { signal });
|
| 79 |
+
if (!res.ok) throw new Error(`${res.status} ${res.statusText}`);
|
| 80 |
+
return (await res.json()) as T;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
function buildFieldsURL(datasetId: string, config: string | null, split: string): string {
|
| 85 |
+
const params = new URLSearchParams();
|
| 86 |
+
params.set('id', datasetId);
|
| 87 |
+
|
| 88 |
+
if (config && config.trim() !== '') {
|
| 89 |
+
params.set('config', config);
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
if (split && split.trim() !== '') {
|
| 93 |
+
params.set('split', split);
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
return `/dataset/fields?${params.toString()}`;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
useEffect(() => {
|
| 100 |
+
console.log('📊 Dataset changed:', cfg.dataset);
|
| 101 |
+
setSelectedCfFields([]);
|
| 102 |
+
setFieldsError(null);
|
| 103 |
+
setAvailableFields([]);
|
| 104 |
+
|
| 105 |
+
if (!cfg.dataset || cfg.dataset === 'custom') return;
|
| 106 |
+
|
| 107 |
+
const ac = new AbortController();
|
| 108 |
+
|
| 109 |
+
const run = async () => {
|
| 110 |
+
try {
|
| 111 |
+
console.log('🔍 Fetching dataset meta...');
|
| 112 |
+
const metaURL = `/dataset/meta?id=${encodeURIComponent(cfg.dataset)}`;
|
| 113 |
+
const meta = await fetchJSON<{
|
| 114 |
+
datasetId: string;
|
| 115 |
+
configs: string[];
|
| 116 |
+
splits: string[];
|
| 117 |
+
}>(metaURL, ac.signal);
|
| 118 |
+
|
| 119 |
+
console.log('📋 Meta data received:', meta);
|
| 120 |
+
|
| 121 |
+
setMetaConfigs(meta.configs || []);
|
| 122 |
+
setMetaSplits(meta.splits || []);
|
| 123 |
+
|
| 124 |
+
const defaultConfig = meta.configs?.length ? meta.configs[0] : null;
|
| 125 |
+
const defaultSplit = meta.splits?.length ?
|
| 126 |
+
(meta.splits.includes('train') ? 'train' : meta.splits[0]) :
|
| 127 |
+
'train';
|
| 128 |
+
|
| 129 |
+
setSelectedConfig(defaultConfig);
|
| 130 |
+
setSelectedSplit(defaultSplit);
|
| 131 |
+
|
| 132 |
+
console.log('🏷️ Fetching fields with config:', defaultConfig, 'split:', defaultSplit);
|
| 133 |
+
setIsLoadingFields(true);
|
| 134 |
+
|
| 135 |
+
const fieldsURL = buildFieldsURL(cfg.dataset, defaultConfig, defaultSplit);
|
| 136 |
+
const fieldsData = await fetchJSON<{ fields: string[] }>(fieldsURL, ac.signal);
|
| 137 |
+
|
| 138 |
+
setAvailableFields(fieldsData.fields || []);
|
| 139 |
+
setFieldsError(null);
|
| 140 |
+
|
| 141 |
+
} catch (err: any) {
|
| 142 |
+
console.error('❌ Error in dataset effect:', err);
|
| 143 |
+
|
| 144 |
+
setMetaConfigs([]);
|
| 145 |
+
setMetaSplits([]);
|
| 146 |
+
setSelectedConfig(null);
|
| 147 |
+
setSelectedSplit('train');
|
| 148 |
+
|
| 149 |
+
setAvailableFields([]);
|
| 150 |
+
const fieldsURL = buildFieldsURL(cfg.dataset, null, 'train');
|
| 151 |
+
setFieldsError(`(${fieldsURL}) → ${err?.message || '欄位讀取失敗'}`);
|
| 152 |
+
} finally {
|
| 153 |
+
setIsLoadingFields(false);
|
| 154 |
+
}
|
| 155 |
+
};
|
| 156 |
+
|
| 157 |
+
run();
|
| 158 |
+
return () => ac.abort();
|
| 159 |
+
}, [cfg.dataset]);
|
| 160 |
+
|
| 161 |
+
useEffect(() => {
|
| 162 |
+
if (!cfg.dataset || cfg.dataset === 'custom') return;
|
| 163 |
+
|
| 164 |
+
console.log('🔄 Config/Split changed - config:', selectedConfig, 'split:', selectedSplit);
|
| 165 |
+
|
| 166 |
+
const ac = new AbortController();
|
| 167 |
+
|
| 168 |
+
const run = async () => {
|
| 169 |
+
try {
|
| 170 |
+
setIsLoadingFields(true);
|
| 171 |
+
|
| 172 |
+
const fieldsURL = buildFieldsURL(cfg.dataset, selectedConfig, selectedSplit);
|
| 173 |
+
const fieldsData = await fetchJSON<{ fields: string[] }>(fieldsURL, ac.signal);
|
| 174 |
+
|
| 175 |
+
setAvailableFields(fieldsData.fields || []);
|
| 176 |
+
setFieldsError(null);
|
| 177 |
+
setSelectedCfFields([]);
|
| 178 |
+
|
| 179 |
+
const statsURL = `/dataset/field-stats?id=${encodeURIComponent(cfg.dataset)}&field=domain&subfield=category`;
|
| 180 |
+
const statsData = await fetchJSON<{ counts: Record<string, Record<string, number>> }>(statsURL, ac.signal);
|
| 181 |
+
setFieldStats(statsData.counts || {});
|
| 182 |
+
|
| 183 |
+
} catch (err: any) {
|
| 184 |
+
console.error('❌ Error fetching fields after config/split change:', err);
|
| 185 |
+
const fieldsURL = buildFieldsURL(cfg.dataset, selectedConfig, selectedSplit);
|
| 186 |
+
setAvailableFields([]);
|
| 187 |
+
setFieldsError(`(${fieldsURL}) → ${err?.message || 'Field Read Failed'}`);
|
| 188 |
+
} finally {
|
| 189 |
+
setIsLoadingFields(false);
|
| 190 |
+
}
|
| 191 |
+
};
|
| 192 |
+
|
| 193 |
+
run();
|
| 194 |
+
return () => ac.abort();
|
| 195 |
+
}, [cfg.dataset, selectedConfig, selectedSplit]);
|
| 196 |
+
|
| 197 |
+
return (
|
| 198 |
+
<div className="space-y-10">
|
| 199 |
+
<div className="grid grid-cols-1 lg:grid-cols-6 gap-8">
|
| 200 |
+
{/* Dataset selection */}
|
| 201 |
+
<div className={`${card} lg:col-span-3`}>
|
| 202 |
+
<div className="flex items-center gap-3 mb-8">
|
| 203 |
+
<div className="p-3 rounded-xl bg-gradient-to-br from-indigo-600 to-fuchsia-600 shadow-md shadow-indigo-600/30">
|
| 204 |
+
<Database className="w-6 h-6 text-white" />
|
| 205 |
+
</div>
|
| 206 |
+
<h3 className={sectionTitle}>Dataset Selection</h3>
|
| 207 |
+
</div>
|
| 208 |
+
|
| 209 |
+
<div className="space-y-4">
|
| 210 |
+
{DATASETS.map((dataset) => (
|
| 211 |
+
<label key={dataset.id} className={choiceRow}>
|
| 212 |
+
<input
|
| 213 |
+
type="radio"
|
| 214 |
+
name="dataset"
|
| 215 |
+
value={dataset.id}
|
| 216 |
+
checked={cfg.dataset === dataset.id}
|
| 217 |
+
onChange={(e) => {
|
| 218 |
+
setField('dataset', e.target.value);
|
| 219 |
+
setShowCustomDatasetInput(false);
|
| 220 |
+
setCustomDataset('');
|
| 221 |
+
setSelectedCfFields([]);
|
| 222 |
+
}}
|
| 223 |
+
className="mt-1 accent-indigo-600"
|
| 224 |
+
/>
|
| 225 |
+
<div className="flex-1">
|
| 226 |
+
<div className="font-semibold text-slate-900">{dataset.name}</div>
|
| 227 |
+
<div className="flex items-center gap-4 text-xs text-slate-500 mt-2">
|
| 228 |
+
{'entities' in dataset && (
|
| 229 |
+
<span>📊 {(dataset as any).entities?.toLocaleString?.() || '-'} entities</span>
|
| 230 |
+
)}
|
| 231 |
+
{'groups' in dataset && <span>👥 {(dataset as any).groups || '-'} groups</span>}
|
| 232 |
+
</div>
|
| 233 |
+
<a
|
| 234 |
+
href={`https://huggingface.co/datasets/${dataset.id}`}
|
| 235 |
+
target="_blank"
|
| 236 |
+
rel="noopener noreferrer"
|
| 237 |
+
className="inline-flex items-center gap-1 text-indigo-600 hover:text-indigo-700 text-xs font-medium mt-2"
|
| 238 |
+
onClick={(e) => e.stopPropagation()}
|
| 239 |
+
>
|
| 240 |
+
<ExternalLink className="w-3.5 h-3.5" />
|
| 241 |
+
View on Hugging Face
|
| 242 |
+
</a>
|
| 243 |
+
</div>
|
| 244 |
+
</label>
|
| 245 |
+
))}
|
| 246 |
+
|
| 247 |
+
{/* Custom dataset */}
|
| 248 |
+
<label className={choiceRow}>
|
| 249 |
+
<input
|
| 250 |
+
type="radio"
|
| 251 |
+
name="dataset"
|
| 252 |
+
value="custom"
|
| 253 |
+
checked={cfg.dataset === 'custom'}
|
| 254 |
+
onChange={(e) => {
|
| 255 |
+
setField('dataset', e.target.value);
|
| 256 |
+
setShowCustomDatasetInput(true);
|
| 257 |
+
setSelectedCfFields([]);
|
| 258 |
+
}}
|
| 259 |
+
className="mt-1 accent-fuchsia-600"
|
| 260 |
+
/>
|
| 261 |
+
<div className="flex-1">
|
| 262 |
+
<div className="font-semibold text-slate-900">🔧 Custom Dataset Upload from Hugging Face</div>
|
| 263 |
+
</div>
|
| 264 |
+
</label>
|
| 265 |
+
|
| 266 |
+
{showCustomDatasetInput && (
|
| 267 |
+
<div className="pl-6 space-y-3 animate-in slide-in-from-top duration-300">
|
| 268 |
+
<input
|
| 269 |
+
type="text"
|
| 270 |
+
placeholder="Input Hugging Face Dataset ID (e.g. AmazonScience/bold)"
|
| 271 |
+
value={customDataset}
|
| 272 |
+
onChange={(e) => {
|
| 273 |
+
setCustomDataset(e.target.value);
|
| 274 |
+
setField('dataset', e.target.value);
|
| 275 |
+
}}
|
| 276 |
+
className={fieldInput}
|
| 277 |
+
/>
|
| 278 |
+
{customDataset && customDataset.includes('/') && (
|
| 279 |
+
<DatasetValidator datasetId={customDataset} />
|
| 280 |
+
)}
|
| 281 |
+
</div>
|
| 282 |
+
)}
|
| 283 |
+
|
| 284 |
+
{cfg.dataset === 'AmazonScience/bold' && !showCustomDatasetInput && (
|
| 285 |
+
<DatasetValidator datasetId="AmazonScience/bold" />
|
| 286 |
+
)}
|
| 287 |
+
</div>
|
| 288 |
+
</div>
|
| 289 |
+
|
| 290 |
+
{/* Counterfactual analysis (placed in the middle column) */}
|
| 291 |
+
<div className={`${card} lg:col-span-3`}>
|
| 292 |
+
<div className="flex items-center gap-3 mb-8">
|
| 293 |
+
<div className="p-3 rounded-xl bg-gradient-to-br from-pink-600 to-rose-600 shadow-md shadow-pink-600/30">
|
| 294 |
+
<Shuffle className="w-6 h-6 text-white" />
|
| 295 |
+
</div>
|
| 296 |
+
<h3 className={sectionTitle}>Counterfactual Setting</h3>
|
| 297 |
+
</div>
|
| 298 |
+
|
| 299 |
+
<div className="space-y-6">
|
| 300 |
+
|
| 301 |
+
<div className="pt-2">
|
| 302 |
+
<label className="block text-sm font-semibold text-slate-800 mb-1">
|
| 303 |
+
Number of Counterfactual
|
| 304 |
+
</label>
|
| 305 |
+
<input
|
| 306 |
+
type="number"
|
| 307 |
+
min={1}
|
| 308 |
+
max={20}
|
| 309 |
+
step={1}
|
| 310 |
+
value={numCounterfactuals}
|
| 311 |
+
onChange={(e) => {
|
| 312 |
+
const v = parseInt(e.target.value || '3', 10);
|
| 313 |
+
setNumCounterfactuals(Number.isFinite(v) ? Math.max(1, Math.min(20, v)) : 3);
|
| 314 |
+
}}
|
| 315 |
+
className={fieldInput}
|
| 316 |
+
/>
|
| 317 |
+
</div>
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
{/* Dataset meta (show dropdowns when configs/splits are available) */}
|
| 321 |
+
{(metaConfigs.length > 0 || metaSplits.length > 0) && (
|
| 322 |
+
<div className="grid grid-cols-1 sm:grid-cols-2 gap-4">
|
| 323 |
+
{metaConfigs.length > 0 && (
|
| 324 |
+
<div>
|
| 325 |
+
<label className="block text-sm font-semibold text-slate-800 mb-1">Dataset Config</label>
|
| 326 |
+
<select
|
| 327 |
+
value={selectedConfig || ''}
|
| 328 |
+
onChange={(e) => setSelectedConfig(e.target.value || null)}
|
| 329 |
+
className={selectInput}
|
| 330 |
+
>
|
| 331 |
+
{metaConfigs.map((c) => (
|
| 332 |
+
<option key={c} value={c}>{c}</option>
|
| 333 |
+
))}
|
| 334 |
+
</select>
|
| 335 |
+
</div>
|
| 336 |
+
)}
|
| 337 |
+
|
| 338 |
+
{metaSplits.length > 0 && (
|
| 339 |
+
<div>
|
| 340 |
+
<label className="block text-sm font-semibold text-slate-800 mb-1">Split</label>
|
| 341 |
+
<select
|
| 342 |
+
value={selectedSplit}
|
| 343 |
+
onChange={(e) => setSelectedSplit(e.target.value)}
|
| 344 |
+
className={selectInput}
|
| 345 |
+
>
|
| 346 |
+
{metaSplits.map((s) => (
|
| 347 |
+
<option key={s} value={s}>{s}</option>
|
| 348 |
+
))}
|
| 349 |
+
</select>
|
| 350 |
+
</div>
|
| 351 |
+
)}
|
| 352 |
+
</div>
|
| 353 |
+
)}
|
| 354 |
+
|
| 355 |
+
{/* Status row */}
|
| 356 |
+
<div className="text-xs text-slate-500 flex items-center gap-2">
|
| 357 |
+
<span>Selected Dataset</span>
|
| 358 |
+
<span className="inline-flex items-center rounded-full bg-slate-800/90 text-white px-2.5 py-1">
|
| 359 |
+
{cfg.dataset || 'Not Selected Yet'}
|
| 360 |
+
</span>
|
| 361 |
+
{selectedConfig && <span className="ml-1">/ {selectedConfig}</span>}
|
| 362 |
+
{selectedSplit && <span className="ml-1">/ {selectedSplit}</span>}
|
| 363 |
+
</div>
|
| 364 |
+
|
| 365 |
+
{/* Field list */}
|
| 366 |
+
<div>
|
| 367 |
+
<div className="flex items-center justify-between mb-2">
|
| 368 |
+
<div className="text-sm font-semibold text-slate-800">Optional fields</div>
|
| 369 |
+
{isLoadingFields && <span className="text-xs text-slate-500">Loading</span>}
|
| 370 |
+
</div>
|
| 371 |
+
|
| 372 |
+
<div className="space-y-4 max-h-64 overflow-auto pr-1">
|
| 373 |
+
{Object.entries(fieldStats).map(([domain, categories]) => (
|
| 374 |
+
<div key={domain} className="bg-white/50 border border-slate-200 rounded-xl p-3 shadow-sm">
|
| 375 |
+
<div className="font-semibold text-slate-700 text-sm mb-2">{domain}</div>
|
| 376 |
+
<div className="grid grid-cols-1 sm:grid-cols-2 gap-x-4 gap-y-2 pl-1">
|
| 377 |
+
{Object.entries(categories).map(([category, count]) => {
|
| 378 |
+
const fieldKey = `${domain}/${category}`;
|
| 379 |
+
return (
|
| 380 |
+
<label
|
| 381 |
+
key={fieldKey}
|
| 382 |
+
className="flex items-center gap-2 text-sm text-slate-800 hover:bg-white/60 px-2 py-1 rounded-md transition-colors"
|
| 383 |
+
>
|
| 384 |
+
<input
|
| 385 |
+
type="checkbox"
|
| 386 |
+
checked={selectedCfFields.includes(fieldKey)}
|
| 387 |
+
onChange={() =>
|
| 388 |
+
setSelectedCfFields((prev) =>
|
| 389 |
+
prev.includes(fieldKey)
|
| 390 |
+
? prev.filter((x) => x !== fieldKey)
|
| 391 |
+
: [...prev, fieldKey]
|
| 392 |
+
)
|
| 393 |
+
}
|
| 394 |
+
className="accent-fuchsia-600"
|
| 395 |
+
/>
|
| 396 |
+
<span>{category}</span>
|
| 397 |
+
<span className="text-xs text-slate-500">({count})</span>
|
| 398 |
+
</label>
|
| 399 |
+
);
|
| 400 |
+
})}
|
| 401 |
+
</div>
|
| 402 |
+
</div>
|
| 403 |
+
))}
|
| 404 |
+
</div>
|
| 405 |
+
</div>
|
| 406 |
+
</div>
|
| 407 |
+
</div>
|
| 408 |
+
|
| 409 |
+
{/* 模型選擇(包含 K / datasetLimit 與 metrictarget 的指定位置) */}
|
| 410 |
+
<div className={`${card} lg:col-span-3`}>
|
| 411 |
+
<div className="flex items-center gap-3 mb-8">
|
| 412 |
+
<div className="p-3 rounded-xl bg-gradient-to-br from-emerald-600 to-teal-600 shadow-md shadow-emerald-600/30">
|
| 413 |
+
<Bot className="w-6 h-6 text-white" />
|
| 414 |
+
</div>
|
| 415 |
+
<h3 className={sectionTitle}>Model Selection</h3>
|
| 416 |
+
</div>
|
| 417 |
+
|
| 418 |
+
<div className="space-y-8">
|
| 419 |
+
{/* 語言模型 */}
|
| 420 |
+
<div>
|
| 421 |
+
<label className="block text-sm font-semibold text-slate-800 mb-2">🤖 Language Generation Model</label>
|
| 422 |
+
<select
|
| 423 |
+
value={cfg.languageModel}
|
| 424 |
+
onChange={(e) => {
|
| 425 |
+
setField('languageModel', e.target.value);
|
| 426 |
+
setShowCustomLanguageInput(e.target.value === 'custom');
|
| 427 |
+
}}
|
| 428 |
+
className={selectInput}
|
| 429 |
+
>
|
| 430 |
+
<option value="">Select a Language Model</option>
|
| 431 |
+
{LM_MODELS.map((m) => (
|
| 432 |
+
<option key={m.id} value={m.id}>
|
| 433 |
+
{m.name}({m.provider})
|
| 434 |
+
</option>
|
| 435 |
+
))}
|
| 436 |
+
<option value="custom">🔧 Custom Model Upload from Hugging Face</option>
|
| 437 |
+
</select>
|
| 438 |
+
|
| 439 |
+
{showCustomLanguageInput && (
|
| 440 |
+
<input
|
| 441 |
+
type="text"
|
| 442 |
+
placeholder="Input Hugging Face Model ID (e.g.:microsoft/DialoGPT-medium)"
|
| 443 |
+
value={customLM}
|
| 444 |
+
onChange={(e) => {
|
| 445 |
+
setCustomLM(e.target.value);
|
| 446 |
+
setField('languageModel', e.target.value);
|
| 447 |
+
}}
|
| 448 |
+
className={`${fieldInput} mt-3`}
|
| 449 |
+
/>
|
| 450 |
+
)}
|
| 451 |
+
|
| 452 |
+
{(customLM || cfg.languageModel) && (
|
| 453 |
+
<div className="mt-3">
|
| 454 |
+
<ModelValidator modelId={customLM || cfg.languageModel} type="language" />
|
| 455 |
+
</div>
|
| 456 |
+
)}
|
| 457 |
+
|
| 458 |
+
{/* 語言模型下方:K 與 datasetLimit */}
|
| 459 |
+
<div className="mt-6 space-y-5">
|
| 460 |
+
<div>
|
| 461 |
+
<label className="block text-sm font-semibold text-slate-800 mb-1">
|
| 462 |
+
Number of Candidates
|
| 463 |
+
<span className="ml-2 text-xs font-normal text-slate-500">The number of candidates generated for each entity</span>
|
| 464 |
+
</label>
|
| 465 |
+
<input
|
| 466 |
+
type="number"
|
| 467 |
+
min={1}
|
| 468 |
+
max={20}
|
| 469 |
+
value={cfg.k}
|
| 470 |
+
onChange={(e) => setField('k', parseInt(e.target.value || '0', 10))}
|
| 471 |
+
className={fieldInput}
|
| 472 |
+
/>
|
| 473 |
+
</div>
|
| 474 |
+
|
| 475 |
+
<div>
|
| 476 |
+
<label className="block text-sm font-semibold text-slate-800 mb-1">
|
| 477 |
+
Testing Data Limit
|
| 478 |
+
</label>
|
| 479 |
+
<input
|
| 480 |
+
type="number"
|
| 481 |
+
min={1}
|
| 482 |
+
max={10000}
|
| 483 |
+
value={datasetLimit}
|
| 484 |
+
onChange={(e) => setDatasetLimit(parseInt(e.target.value || '0', 10))}
|
| 485 |
+
className={fieldInput}
|
| 486 |
+
/>
|
| 487 |
+
</div>
|
| 488 |
+
</div>
|
| 489 |
+
</div>
|
| 490 |
+
|
| 491 |
+
{/* 分類任務(固定選項) */}
|
| 492 |
+
<div className="mt-6">
|
| 493 |
+
<label className="block text-sm font-semibold text-slate-800 mb-1">
|
| 494 |
+
👻 Feature Extraction Model
|
| 495 |
+
</label>
|
| 496 |
+
<select
|
| 497 |
+
value={classificationTask}
|
| 498 |
+
onChange={(e) => setClassificationTask(e.target.value as any)}
|
| 499 |
+
className={selectInput}
|
| 500 |
+
>
|
| 501 |
+
<option value="sentiment">Sentiment (0–1, Neutral ≈ 0.5)</option>
|
| 502 |
+
<option value="regard">Regard (0–2, Neutral ≈ 1.0)</option>
|
| 503 |
+
<option value="stereotype">Stereotype (0–1, Neutral ≈ 0.0)</option>
|
| 504 |
+
<option value="personality">Personality (0–1, Neutral ≈ 0.2)</option>
|
| 505 |
+
<option value="toxicity">Toxicity (0–1, Neutral ≈ 0.0)</option>
|
| 506 |
+
|
| 507 |
+
</select>
|
| 508 |
+
</div>
|
| 509 |
+
|
| 510 |
+
{/* 毒性模型選擇(只有當任務為 toxicity 時顯示) */}
|
| 511 |
+
{classificationTask === 'toxicity' && (
|
| 512 |
+
<div className="mt-4">
|
| 513 |
+
<label className="block text-sm font-semibold text-slate-800 mb-1">
|
| 514 |
+
Toxicity Model Selection
|
| 515 |
+
</label>
|
| 516 |
+
<select
|
| 517 |
+
value={toxicityModelChoice}
|
| 518 |
+
onChange={(e) => setToxicityModelChoice(e.target.value as any)}
|
| 519 |
+
className={selectInput}
|
| 520 |
+
>
|
| 521 |
+
<option value="detoxify">unitary/toxic-bert(detoxify)</option>
|
| 522 |
+
<option value="junglelee">JungleLee/bert-toxic-comment-classification</option>
|
| 523 |
+
</select>
|
| 524 |
+
</div>
|
| 525 |
+
)}
|
| 526 |
+
|
| 527 |
+
{/* 評分模型下方:目標指標值 */}
|
| 528 |
+
<div className="mt-6">
|
| 529 |
+
<label className="block text-sm font-semibold text-slate-800 mb-1">
|
| 530 |
+
Metric Target Value
|
| 531 |
+
<span className="ml-2 text-xs font-normal text-slate-500">Indicator thresholds used to determine compliance</span>
|
| 532 |
+
</label>
|
| 533 |
+
<input
|
| 534 |
+
type="number"
|
| 535 |
+
min={0}
|
| 536 |
+
max={2}
|
| 537 |
+
step={0.01}
|
| 538 |
+
value={cfg.metrictarget}
|
| 539 |
+
onChange={(e) => setField('metrictarget', parseFloat(e.target.value || '0'))}
|
| 540 |
+
className={fieldInput}
|
| 541 |
+
/>
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
</div>
|
| 545 |
+
</div>
|
| 546 |
+
</div>
|
| 547 |
+
{/* Fine-tuning 設定 */}
|
| 548 |
+
<div className={`${card} lg:col-span-3`}>
|
| 549 |
+
<div className="flex items-center gap-3 mb-8">
|
| 550 |
+
<div className="p-3 rounded-xl bg-gradient-to-br from-orange-500 to-yellow-500 shadow-md shadow-orange-500/30">
|
| 551 |
+
<Database className="w-6 h-6 text-white" />
|
| 552 |
+
</div>
|
| 553 |
+
<h3 className={sectionTitle}>Fine-tuning Setting</h3>
|
| 554 |
+
</div>
|
| 555 |
+
|
| 556 |
+
<div className="space-y-6">
|
| 557 |
+
<label className="flex items-center gap-2">
|
| 558 |
+
<input
|
| 559 |
+
type="checkbox"
|
| 560 |
+
checked={cfg.enableFineTuning}
|
| 561 |
+
onChange={(e) => setField('enableFineTuning', e.target.checked)}
|
| 562 |
+
className="accent-orange-500"
|
| 563 |
+
/>
|
| 564 |
+
<span className="text-sm text-slate-800 font-semibold">Enable Fine-tuning</span>
|
| 565 |
+
</label>
|
| 566 |
+
|
| 567 |
+
{cfg.enableFineTuning && (
|
| 568 |
+
<div className="space-y-4 pl-4 border-l-2 border-orange-200">
|
| 569 |
+
{/* Epochs */}
|
| 570 |
+
<div>
|
| 571 |
+
<label className="block text-sm font-semibold text-slate-800 mb-1">
|
| 572 |
+
Training Epochs
|
| 573 |
+
</label>
|
| 574 |
+
<input
|
| 575 |
+
type="number"
|
| 576 |
+
min={1}
|
| 577 |
+
max={100}
|
| 578 |
+
value={ftEpochs}
|
| 579 |
+
onChange={(e) => setFtEpochs(parseInt(e.target.value || '3', 10))}
|
| 580 |
+
className={fieldInput}
|
| 581 |
+
/>
|
| 582 |
+
</div>
|
| 583 |
+
|
| 584 |
+
{/* Batch Size */}
|
| 585 |
+
<div>
|
| 586 |
+
<label className="block text-sm font-semibold text-slate-800 mb-1">
|
| 587 |
+
Batch Size
|
| 588 |
+
</label>
|
| 589 |
+
<input
|
| 590 |
+
type="number"
|
| 591 |
+
min={1}
|
| 592 |
+
max={256}
|
| 593 |
+
value={ftBatchSize}
|
| 594 |
+
onChange={(e) => setFtBatchSize(parseInt(e.target.value || '8', 10))}
|
| 595 |
+
className={fieldInput}
|
| 596 |
+
/>
|
| 597 |
+
</div>
|
| 598 |
+
|
| 599 |
+
{/* Learning Rate */}
|
| 600 |
+
<div>
|
| 601 |
+
<label className="block text-sm font-semibold text-slate-800 mb-1">
|
| 602 |
+
Learning Rate
|
| 603 |
+
</label>
|
| 604 |
+
<input
|
| 605 |
+
type="number"
|
| 606 |
+
step={0.00001}
|
| 607 |
+
value={ftLR}
|
| 608 |
+
onChange={(e) => setFtLR(parseFloat(e.target.value || '0.00005'))}
|
| 609 |
+
className={fieldInput}
|
| 610 |
+
/>
|
| 611 |
+
</div>
|
| 612 |
+
</div>
|
| 613 |
+
)}
|
| 614 |
+
</div>
|
| 615 |
+
</div>
|
| 616 |
+
|
| 617 |
+
</div>
|
| 618 |
+
|
| 619 |
+
{/* 開始按鈕 */}
|
| 620 |
+
<div className="flex">
|
| 621 |
+
<button
|
| 622 |
+
onClick={() => {
|
| 623 |
+
const fullCfg = {
|
| 624 |
+
...cfg,
|
| 625 |
+
selectedCfFields,
|
| 626 |
+
numCounterfactuals,
|
| 627 |
+
classificationTask,
|
| 628 |
+
toxicityModelChoice,
|
| 629 |
+
finetuneParams: {
|
| 630 |
+
epochs: ftEpochs,
|
| 631 |
+
batchSize: ftBatchSize,
|
| 632 |
+
learningRate: ftLR,
|
| 633 |
+
},
|
| 634 |
+
};
|
| 635 |
+
onRun(fullCfg, {
|
| 636 |
+
datasetLimit
|
| 637 |
+
});
|
| 638 |
+
}}
|
| 639 |
+
disabled={!canStart}
|
| 640 |
+
className="relative w-full group overflow-hidden rounded-2xl px-6 py-4 text-white font-semibold bg-gradient-to-r from-indigo-600 via-violet-600 to-fuchsia-600 shadow-lg shadow-indigo-600/20 enabled:hover:shadow-indigo-600/40 transition-all enabled:hover:translate-y-[-1px] enabled:active:translate-y-0 disabled:opacity-60 disabled:cursor-not-allowed"
|
| 641 |
+
>
|
| 642 |
+
<span className="relative z-10">🚀 Start</span>
|
| 643 |
+
<span className="absolute inset-0 opacity-0 group-hover:opacity-100 transition-opacity bg-[radial-gradient(1200px_200px_at_50%_-40%,rgba(255,255,255,0.35),transparent_60%)]" />
|
| 644 |
+
</button>
|
| 645 |
+
</div>
|
| 646 |
+
|
| 647 |
+
</div>
|
| 648 |
+
);
|
| 649 |
+
}
|
frontend/src/pages/ResultsPage.tsx
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import PipelineProgress from '../components/PipelineProgress';
|
| 2 |
+
import { useJobRunner } from '../hooks/JobRunnerProvider';
|
| 3 |
+
|
| 4 |
+
export default function ResultsPage() {
|
| 5 |
+
const { result, resp, loading, error, url } = useJobRunner();
|
| 6 |
+
|
| 7 |
+
if (loading && !resp) {
|
| 8 |
+
return (
|
| 9 |
+
<div className="space-y-6">
|
| 10 |
+
<PipelineProgress />
|
| 11 |
+
<section className="grid grid-cols-1 md:grid-cols-2 gap-6">
|
| 12 |
+
{[0, 1].map((i) => (
|
| 13 |
+
<div key={i} className="rounded-2xl border bg-white/70 backdrop-blur p-3">
|
| 14 |
+
<div className="w-full h-64 rounded-xl bg-gradient-to-b from-slate-200 to-slate-100 animate-pulse" />
|
| 15 |
+
<div className="mt-3 h-4 w-40 rounded bg-slate-200 animate-pulse" />
|
| 16 |
+
</div>
|
| 17 |
+
))}
|
| 18 |
+
</section>
|
| 19 |
+
</div>
|
| 20 |
+
);
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
if (error) {
|
| 24 |
+
return (
|
| 25 |
+
<div className="p-6 rounded-2xl bg-red-50 border border-red-200 text-red-700">
|
| 26 |
+
{error}
|
| 27 |
+
</div>
|
| 28 |
+
);
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
if (!result || !resp) {
|
| 32 |
+
return (
|
| 33 |
+
<div className="p-6 rounded-2xl bg-white/70 border border-white/40">
|
| 34 |
+
Task not executed yet
|
| 35 |
+
</div>
|
| 36 |
+
);
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
const m = result.metrics!;
|
| 40 |
+
const plots = resp.results.plots;
|
| 41 |
+
|
| 42 |
+
const originalSrc = url(plots.original_sentiment);
|
| 43 |
+
const cfSrc = url(plots.counterfactual_sentiment);
|
| 44 |
+
|
| 45 |
+
const r = resp.results as any;
|
| 46 |
+
const links: { label: string; href: string }[] = [];
|
| 47 |
+
|
| 48 |
+
if (r?.generation_file) {
|
| 49 |
+
links.push({ label: 'Generation CSV', href: r.generation_file });
|
| 50 |
+
}
|
| 51 |
+
if (r?.sentiment_subset_file) {
|
| 52 |
+
links.push({ label: 'Original sentiment subset CSV', href: r.sentiment_subset_file });
|
| 53 |
+
}
|
| 54 |
+
if (r?.cf_sentiment_subset_file) {
|
| 55 |
+
links.push({ label: 'CF sentiment subset CSV', href: r.cf_sentiment_subset_file });
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
if (r?.run_config_files?.markdown) {
|
| 59 |
+
links.push({ label: 'Run Config (Markdown)', href: r.run_config_files.markdown });
|
| 60 |
+
}
|
| 61 |
+
if (r?.run_config_files?.json) {
|
| 62 |
+
links.push({ label: 'Run Config (JSON)', href: r.run_config_files.json });
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
if (r?.finetuned_model_zip) {
|
| 66 |
+
links.push({ label: 'Fine-tuned Model (ZIP)', href: r.finetuned_model_zip });
|
| 67 |
+
} else if (r?.finetuned_model_dir) {
|
| 68 |
+
links.push({ label: 'Fine-tuned Model Folder', href: r.finetuned_model_dir });
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
return (
|
| 72 |
+
<div className="space-y-6">
|
| 73 |
+
{loading && <PipelineProgress />}
|
| 74 |
+
|
| 75 |
+
<section className="p-6 rounded-2xl border border-white/40 bg-white/70 backdrop-blur">
|
| 76 |
+
<h2 className="text-lg font-semibold mb-3">Metric</h2>
|
| 77 |
+
<div className="grid grid-cols-1 sm:grid-cols-2 gap-4">
|
| 78 |
+
<div className="p-4 rounded-xl bg-slate-50 border">
|
| 79 |
+
<div className="text-slate-500 text-sm">Original Difference</div>
|
| 80 |
+
<div className="text-2xl font-bold">{m.finalMeanDiff.toFixed(4)}</div>
|
| 81 |
+
</div>
|
| 82 |
+
<div className="p-4 rounded-xl bg-slate-50 border">
|
| 83 |
+
<div className="text-slate-500 text-sm">CF Difference</div>
|
| 84 |
+
<div className="text-2xl font-bold">
|
| 85 |
+
{resp.results.metrics.cfFinalMeanDiff.toFixed(4)}
|
| 86 |
+
</div>
|
| 87 |
+
</div>
|
| 88 |
+
</div>
|
| 89 |
+
</section>
|
| 90 |
+
|
| 91 |
+
<section className="p-6 rounded-2xl border border-white/40 bg-white/70 backdrop-blur">
|
| 92 |
+
<h2 className="text-lg font-semibold mb-4">Distribution</h2>
|
| 93 |
+
<div className="grid grid-cols-1 md:grid-cols-2 gap-6">
|
| 94 |
+
<figure className="rounded-xl overflow-hidden border bg-white">
|
| 95 |
+
<img
|
| 96 |
+
src={originalSrc}
|
| 97 |
+
alt="Original distribution"
|
| 98 |
+
className="w-full h-auto"
|
| 99 |
+
loading="lazy"
|
| 100 |
+
onError={(e) => {
|
| 101 |
+
e.currentTarget.alt = 'Original image loading failed';
|
| 102 |
+
}}
|
| 103 |
+
/>
|
| 104 |
+
<figcaption className="p-3 text-sm text-slate-600">Original</figcaption>
|
| 105 |
+
</figure>
|
| 106 |
+
|
| 107 |
+
<figure className="rounded-xl overflow-hidden border bg-white">
|
| 108 |
+
<img
|
| 109 |
+
src={cfSrc}
|
| 110 |
+
alt="Counterfactual distribution"
|
| 111 |
+
className="w-full h-auto"
|
| 112 |
+
loading="lazy"
|
| 113 |
+
onError={(e) => {
|
| 114 |
+
e.currentTarget.alt = 'Counterfactual image loading failed';
|
| 115 |
+
}}
|
| 116 |
+
/>
|
| 117 |
+
<figcaption className="p-3 text-sm text-slate-600">Counterfactual Augmented</figcaption>
|
| 118 |
+
</figure>
|
| 119 |
+
</div>
|
| 120 |
+
</section>
|
| 121 |
+
|
| 122 |
+
{links.length > 0 && (
|
| 123 |
+
<section className="p-6 rounded-2xl border border-white/40 bg-white/70 backdrop-blur">
|
| 124 |
+
<h2 className="text-lg font-semibold mb-4">Download Report</h2>
|
| 125 |
+
<ul className="space-y-2">
|
| 126 |
+
{links.map((l) => (
|
| 127 |
+
<li key={l.label}>
|
| 128 |
+
<a
|
| 129 |
+
className="text-indigo-600 hover:underline"
|
| 130 |
+
href={url(l.href)}
|
| 131 |
+
target="_blank"
|
| 132 |
+
rel="noreferrer"
|
| 133 |
+
download
|
| 134 |
+
>
|
| 135 |
+
{l.label}
|
| 136 |
+
</a>
|
| 137 |
+
</li>
|
| 138 |
+
))}
|
| 139 |
+
</ul>
|
| 140 |
+
</section>
|
| 141 |
+
)}
|
| 142 |
+
</div>
|
| 143 |
+
);
|
| 144 |
+
}
|
frontend/src/services/api.ts
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// src/services/api.ts
|
| 2 |
+
import type { JobConfig } from '../types';
|
| 3 |
+
|
| 4 |
+
export type PipelinePlots = {
|
| 5 |
+
original_sentiment: string;
|
| 6 |
+
counterfactual_sentiment: string;
|
| 7 |
+
};
|
| 8 |
+
|
| 9 |
+
export type PipelineResultsDTO = {
|
| 10 |
+
data_loaded: number;
|
| 11 |
+
model_loaded: string;
|
| 12 |
+
|
| 13 |
+
generation_file: string;
|
| 14 |
+
generation_samples: number;
|
| 15 |
+
|
| 16 |
+
counterfactual_file: string;
|
| 17 |
+
counterfactual_added: number;
|
| 18 |
+
counterfactual_total: number;
|
| 19 |
+
|
| 20 |
+
sampling_method: string;
|
| 21 |
+
sentiment_subset_file: string;
|
| 22 |
+
sentiment_subset_size: number;
|
| 23 |
+
|
| 24 |
+
cf_sentiment_subset_file: string;
|
| 25 |
+
cf_sentiment_subset_size: number;
|
| 26 |
+
|
| 27 |
+
// 後端還會給 stereotype 的欄位,但前端不需要可不宣告
|
| 28 |
+
config_used: import('../types').JobConfig;
|
| 29 |
+
metrics: import('../types').JobMetrics & {
|
| 30 |
+
finalMeanDiff: number;
|
| 31 |
+
cfFinalMeanDiff: number;
|
| 32 |
+
};
|
| 33 |
+
plots: PipelinePlots;
|
| 34 |
+
};
|
| 35 |
+
|
| 36 |
+
export type PipelineResponseDTO = {
|
| 37 |
+
status: 'success' | 'error';
|
| 38 |
+
message: string;
|
| 39 |
+
timestamp: string;
|
| 40 |
+
results: PipelineResultsDTO;
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
const BASE = import.meta.env.VITE_API_BASE ?? '/api';
|
| 44 |
+
|
| 45 |
+
async function runPipeline(config: any) {
|
| 46 |
+
const r = await fetch(`${BASE}/pipeline`, {
|
| 47 |
+
method: 'POST',
|
| 48 |
+
headers: { 'Content-Type': 'application/json' },
|
| 49 |
+
body: JSON.stringify({ config }),
|
| 50 |
+
});
|
| 51 |
+
if (!r.ok) {
|
| 52 |
+
const text = await r.text();
|
| 53 |
+
throw new Error(`Pipeline failed (${r.status}): ${text}`);
|
| 54 |
+
}
|
| 55 |
+
return r.json();
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
async function checkHealth() {
|
| 59 |
+
const r = await fetch(`${BASE}/health`);
|
| 60 |
+
return r.json();
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
function resolvePath(p?: string) {
|
| 64 |
+
if (!p) return '';
|
| 65 |
+
if (p.startsWith('http')) return p;
|
| 66 |
+
const path = p.startsWith('/') ? p : `/${p}`;
|
| 67 |
+
return `${BASE}${path}`;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
export const MLBiasAPI = { runPipeline, checkHealth, resolvePath };
|
frontend/src/services/hf.ts
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// src/services/hf.ts
|
| 2 |
+
export async function fetchHFModel(modelId: string) {
|
| 3 |
+
const r = await fetch(`https://huggingface.co/api/models/${modelId}`);
|
| 4 |
+
if (!r.ok) throw new Error('The model does not exist or cannot be accessed');
|
| 5 |
+
return r.json();
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
export async function fetchHFDataset(datasetId: string) {
|
| 9 |
+
const r = await fetch(`https://huggingface.co/api/datasets/${datasetId}`);
|
| 10 |
+
if (!r.ok) throw new Error('The dataset does not exist or cannot be accessed');
|
| 11 |
+
return r.json();
|
| 12 |
+
}
|
frontend/src/types/index.ts
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// src/types/index.ts
|
| 2 |
+
export type Dataset = {
|
| 3 |
+
id: string;
|
| 4 |
+
name: string;
|
| 5 |
+
};
|
| 6 |
+
|
| 7 |
+
export type Model = {
|
| 8 |
+
id: string;
|
| 9 |
+
name: string;
|
| 10 |
+
type: 'language' | 'scorer';
|
| 11 |
+
description: string;
|
| 12 |
+
provider: string;
|
| 13 |
+
};
|
| 14 |
+
|
| 15 |
+
export type JobConfig = {
|
| 16 |
+
dataset: string;
|
| 17 |
+
languageModel: string;
|
| 18 |
+
scorerModel: string;
|
| 19 |
+
k: number;
|
| 20 |
+
tau: number;
|
| 21 |
+
iterations: number;
|
| 22 |
+
seed: number;
|
| 23 |
+
enableFineTuning: boolean;
|
| 24 |
+
counterfactual: boolean;
|
| 25 |
+
metrictarget: number;
|
| 26 |
+
numCounterfactuals: number;
|
| 27 |
+
selectedCfFields?: string[];
|
| 28 |
+
|
| 29 |
+
};
|
| 30 |
+
|
| 31 |
+
export type JobStatus = 'running' | 'completed' | 'failed';
|
| 32 |
+
|
| 33 |
+
export type ChartPoint = {
|
| 34 |
+
iteration: number;
|
| 35 |
+
meanDifference: number;
|
| 36 |
+
groupA: number;
|
| 37 |
+
groupB: number;
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
export type JobMetrics = {
|
| 41 |
+
finalMeanDiff: number;
|
| 42 |
+
reductionPct: number;
|
| 43 |
+
stableCoverage: number;
|
| 44 |
+
};
|
| 45 |
+
|
| 46 |
+
export type JobResult = {
|
| 47 |
+
id: string;
|
| 48 |
+
status: JobStatus;
|
| 49 |
+
progress: number; // 0-100
|
| 50 |
+
config: JobConfig;
|
| 51 |
+
createdAt: string;
|
| 52 |
+
updatedAt: string;
|
| 53 |
+
completedAt?: string;
|
| 54 |
+
charts?: ChartPoint[];
|
| 55 |
+
metrics?: JobMetrics;
|
| 56 |
+
};
|
| 57 |
+
|
| 58 |
+
export type Extras = {
|
| 59 |
+
datasetLimit: number;
|
| 60 |
+
};
|
frontend/src/vite-env.d.ts
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
/// <reference types="vite/client" />
|
frontend/tailwind.config.js
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/** @type {import('tailwindcss').Config} */
|
| 2 |
+
export default {
|
| 3 |
+
content: [
|
| 4 |
+
"./index.html",
|
| 5 |
+
"./src/**/*.{js,ts,jsx,tsx}",
|
| 6 |
+
],
|
| 7 |
+
theme: {
|
| 8 |
+
extend: {},
|
| 9 |
+
},
|
| 10 |
+
plugins: [require('@tailwindcss/forms')],
|
| 11 |
+
}
|
frontend/tsconfig.app.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"compilerOptions": {
|
| 3 |
+
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
|
| 4 |
+
"target": "ES2022",
|
| 5 |
+
"useDefineForClassFields": true,
|
| 6 |
+
"lib": ["ES2022", "DOM", "DOM.Iterable"],
|
| 7 |
+
"module": "ESNext",
|
| 8 |
+
"skipLibCheck": true,
|
| 9 |
+
|
| 10 |
+
/* Bundler mode */
|
| 11 |
+
"moduleResolution": "bundler",
|
| 12 |
+
"allowImportingTsExtensions": true,
|
| 13 |
+
"verbatimModuleSyntax": true,
|
| 14 |
+
"moduleDetection": "force",
|
| 15 |
+
"noEmit": true,
|
| 16 |
+
"jsx": "react-jsx",
|
| 17 |
+
|
| 18 |
+
/* Linting */
|
| 19 |
+
"strict": true,
|
| 20 |
+
"noUnusedLocals": false,
|
| 21 |
+
"noUnusedParameters": false,
|
| 22 |
+
"erasableSyntaxOnly": true,
|
| 23 |
+
"noFallthroughCasesInSwitch": true,
|
| 24 |
+
"noUncheckedSideEffectImports": true
|
| 25 |
+
},
|
| 26 |
+
"include": ["src"]
|
| 27 |
+
}
|
frontend/tsconfig.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"files": [],
|
| 3 |
+
"references": [
|
| 4 |
+
{ "path": "./tsconfig.app.json" },
|
| 5 |
+
{ "path": "./tsconfig.node.json" }
|
| 6 |
+
]
|
| 7 |
+
}
|
| 8 |
+
|
frontend/tsconfig.node.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"compilerOptions": {
|
| 3 |
+
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
|
| 4 |
+
"target": "ES2023",
|
| 5 |
+
"lib": ["ES2023"],
|
| 6 |
+
"module": "ESNext",
|
| 7 |
+
"skipLibCheck": true,
|
| 8 |
+
|
| 9 |
+
/* Bundler mode */
|
| 10 |
+
"moduleResolution": "bundler",
|
| 11 |
+
"allowImportingTsExtensions": true,
|
| 12 |
+
"verbatimModuleSyntax": true,
|
| 13 |
+
"moduleDetection": "force",
|
| 14 |
+
"noEmit": true,
|
| 15 |
+
|
| 16 |
+
/* Linting */
|
| 17 |
+
"strict": true,
|
| 18 |
+
"noUnusedLocals": true,
|
| 19 |
+
"noUnusedParameters": true,
|
| 20 |
+
"erasableSyntaxOnly": true,
|
| 21 |
+
"noFallthroughCasesInSwitch": true,
|
| 22 |
+
"noUncheckedSideEffectImports": true
|
| 23 |
+
},
|
| 24 |
+
"include": ["vite.config.ts"]
|
| 25 |
+
}
|
frontend/vite.config.ts
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { defineConfig } from 'vite'
|
| 2 |
+
import react from '@vitejs/plugin-react'
|
| 3 |
+
|
| 4 |
+
// https://vite.dev/config/
|
| 5 |
+
export default defineConfig({
|
| 6 |
+
plugins: [react()],
|
| 7 |
+
})
|
nginx.conf.template
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
worker_processes auto;
|
| 2 |
+
events {
|
| 3 |
+
worker_connections 1024;
|
| 4 |
+
}
|
| 5 |
+
|
| 6 |
+
http {
|
| 7 |
+
include /etc/nginx/mime.types;
|
| 8 |
+
default_type application/octet-stream;
|
| 9 |
+
sendfile on;
|
| 10 |
+
keepalive_timeout 65;
|
| 11 |
+
|
| 12 |
+
server {
|
| 13 |
+
listen 7860;
|
| 14 |
+
server_name _;
|
| 15 |
+
root /usr/share/nginx/html;
|
| 16 |
+
index index.html;
|
| 17 |
+
|
| 18 |
+
# 前端單頁應用
|
| 19 |
+
location / {
|
| 20 |
+
try_files $uri $uri/ /index.html;
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
# 後端 API 反向代理
|
| 24 |
+
location /api/ {
|
| 25 |
+
proxy_pass http://127.0.0.1:5001/;
|
| 26 |
+
proxy_http_version 1.1;
|
| 27 |
+
proxy_set_header Upgrade $http_upgrade;
|
| 28 |
+
proxy_set_header Connection "upgrade";
|
| 29 |
+
proxy_set_header Host $host;
|
| 30 |
+
proxy_set_header X-Real-IP $remote_addr;
|
| 31 |
+
proxy_buffering off;
|
| 32 |
+
|
| 33 |
+
proxy_connect_timeout 3000s;
|
| 34 |
+
proxy_send_timeout 3000s;
|
| 35 |
+
proxy_read_timeout 3000s;
|
| 36 |
+
proxy_redirect off;
|
| 37 |
+
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
}
|