peihsin0715 committed on
Commit
7c447a5
·
1 Parent(s): 842d2fd

Add all project files for HF Spaces deployment

Browse files
Dockerfile ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ---------- Frontend build ----------
# Stage "fe": builds the frontend; only /app/frontend/dist is copied into the runtime stage.
FROM node:20-bullseye AS fe
WORKDIR /app/frontend
# The package manifests are copied before the rest of the sources, so `npm ci`
# runs in its own layer ahead of the full source copy.
COPY frontend/package*.json ./
RUN npm ci
COPY frontend/ ./
RUN npm run build

# ---------- Backend build ----------
# Stage "be": installs the Python dependencies into /usr/local with a compiler
# toolchain available; /usr/local and /app/backend are copied into the runtime stage.
FROM python:3.11-slim AS be
ENV PIP_NO_CACHE_DIR=1 PYTHONUNBUFFERED=1 PIP_PREFER_BINARY=1
WORKDIR /app
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential gcc g++ gfortran make pkg-config \
    libopenblas-dev liblapack-dev git \
    && rm -rf /var/lib/apt/lists/*
RUN python -m pip install --upgrade pip setuptools wheel
# torch is installed on its own from the PyTorch CPU wheel index.
RUN pip install --index-url https://download.pytorch.org/whl/cpu "torch==2.3.1"
COPY backend/requirements.txt ./backend/requirements.txt
# Replaces any requirements line starting with "torch"/"Torch" followed by a
# space or a version operator with a comment; `|| true` ignores a sed failure.
RUN sed -i 's/^[Tt]orch[[:space:]=<>!].*/# torch pinned separately (CPU)/' backend/requirements.txt || true
# Best effort: a missing binary wheel for blis only prints a message.
RUN pip install --only-binary=:all: blis || echo "Precompiled blis not available"
# NOTE(review): if the normal install fails, the retry uses --no-deps, so the
# build continues without dependency resolution and the image may be missing
# transitive packages. Confirm this fallback is intended.
RUN pip install -r backend/requirements.txt || pip install -r backend/requirements.txt --no-deps
COPY backend/ ./backend/

# ---------- Runtime ----------
# Stage "runtime": nginx and gunicorn run side by side under supervisord.
FROM python:3.11-slim AS runtime
ENV PYTHONUNBUFFERED=1 PIP_NO_CACHE_DIR=1 PORT=7860
WORKDIR /app
RUN apt-get update && apt-get install -y --no-install-recommends \
    nginx supervisor ca-certificates \
    libgomp1 libopenblas0 \
    && rm -rf /var/lib/apt/lists/*
COPY --from=fe /app/frontend/dist /usr/share/nginx/html
# The whole /usr/local tree (interpreter site-packages and scripts) comes from the build stage.
COPY --from=be /usr/local /usr/local
COPY --from=be /app/backend /app/backend
RUN python -m pip install --no-cache-dir gunicorn

# NOTE(review): the template is copied verbatim as nginx.conf; no substitution
# step is visible in this file, so any placeholders it contains (e.g. for PORT)
# would have to be resolved elsewhere. Verify against nginx.conf.template.
COPY nginx.conf.template /etc/nginx/nginx.conf

# Writes the supervisord configuration: program "api" (gunicorn serving
# server:app on 0.0.0.0:5001 from /app/backend, with --timeout 0) and program
# "nginx" (foreground), both logging to the container's stdout/stderr.
# The lines inside the quoted printf string are kept at column 0 because leading
# whitespace there would become part of the generated file.
RUN mkdir -p /etc/supervisor/conf.d && \
    printf "[program:api]\n\
command=gunicorn --workers 2 --threads 8 --timeout 0 --chdir /app/backend -b 0.0.0.0:5001 server:app\n\
priority=10\nautostart=true\nautorestart=true\n\
stdout_logfile=/dev/stdout\nstderr_logfile=/dev/stderr\n\
stdout_logfile_maxbytes=0\nstderr_logfile_maxbytes=0\n\n\
[program:nginx]\n\
command=nginx -g \"daemon off;\"\n\
priority=20\nautostart=true\nautorestart=true\n\
stdout_logfile=/dev/stdout\nstderr_logfile=/dev/stderr\n\
stdout_logfile_maxbytes=0\nstderr_logfile_maxbytes=0\n\n\
[supervisord]\nlogfile=/dev/stdout\nlogfile_maxbytes=0\nnodaemon=true\nuser=root\n" \
    > /etc/supervisor/conf.d/app.conf

EXPOSE 7860
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/app.conf"]
backend/requirements.txt ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.10.1
2
+ aiohttp==3.9.1
3
+ aiosignal==1.3.1
4
+ alabaster>=0.7,<0.8
5
+ annotated-types==0.6.0
6
+ anyio==4.2.0
7
+ arrow==1.3.0
8
+ attrs==23.2.0
9
+ babel==2.17.0
10
+ beautifulsoup4==4.13.3
11
+ bibtexparser==1.4.3
12
+ boto3==1.36.14
13
+ bs4==0.0.2
14
+ catalogue==2.0.10
15
+ certifi==2023.11.17
16
+ charset-normalizer==3.3.2
17
+ click==8.1.7
18
+ cloudpathlib==0.20.0
19
+ colorama==0.4.6
20
+ confection==0.1.5
21
+ contourpy==1.3.1
22
+ cycler==0.12.1
23
+ cymem==2.0.11
24
+ dataclasses-json==0.6.3
25
+ datasets==2.18.0
26
+ defusedxml==0.7.1
27
+ Deprecated==1.2.18
28
+ dill==0.3.8
29
+ distro==1.9.0
30
+ fake-useragent==2.0.3
31
+ fastapi==0.109.0
32
+ filelock==3.17.0
33
+ Flask==2.0.3
34
+ Flask-Cors==3.0.10
35
+ fonttools==4.56.0
36
+ fpdf2==2.8.3
37
+ free-proxy==1.1.3
38
+ frozenlist==1.4.1
39
+ fsspec==2024.2.0
40
+ gensim==4.3.3
41
+ h11==0.14.0
42
+ hdbscan==0.8.40
43
+ hf-xet==1.1.9
44
+ httpcore==1.0.2
45
+ httpx==0.26.0
46
+ huggingface-hub==0.34.4
47
+ idna==3.6
48
+ imagesize==1.4.1
49
+ itsdangerous==2.2.0
50
+ Jinja2==3.0.3
51
+ jmespath==1.0.1
52
+ joblib==1.4.2
53
+ jsonpatch==1.33
54
+ jsonpointer==2.4
55
+ kiwisolver==1.4.8
56
+ langchain==0.1.1
57
+ langchain-community==0.0.13
58
+ langchain-core==0.1.13
59
+ langchain-openai==0.0.3
60
+ langcodes==3.5.0
61
+ langserve==0.0.39
62
+ langsmith==0.0.83
63
+ language_data==1.3.0
64
+ lxml==5.3.1
65
+ marisa-trie==1.2.1
66
+ markdown-it-py==3.0.0
67
+ MarkupSafe==3.0.2
68
+ marshmallow==3.20.2
69
+ matplotlib==3.10.0
70
+ mdurl==0.1.2
71
+ mpmath==1.3.0
72
+ multidict==6.0.4
73
+ multiprocess==0.70.16
74
+ murmurhash==1.0.12
75
+ mypy-extensions==1.0.0
76
+ narwhals==1.39.0
77
+ networkx==3.4.2
78
+ nltk==3.8.1
79
+ numpy==1.26.4
80
+ openai==1.9.0
81
+ orjson==3.9.12
82
+ outcome==1.3.0.post0
83
+ packaging==23.2
84
+ pandas==2.1.4
85
+ Pillow==9.5.0
86
+ plotly==6.0.1
87
+ preshed==3.0.9
88
+ psutil==7.0.0
89
+ pyarrow==14.0.2
90
+ pyarrow-hotfix==0.7
91
+ pyasn1==0.6.1
92
+ pydantic==2.5.3
93
+ pydantic_core==2.14.6
94
+ Pygments==2.19.1
95
+ pyparsing==3.2.1
96
+ PySocks==1.7.1
97
+ python-dateutil==2.9.0.post0
98
+ python-dotenv==1.0.0
99
+ pytz==2025.1
100
+ PyYAML==6.0.1
101
+ rank-bm25==0.2.2
102
+ regex==2023.12.25
103
+ requests==2.32.3
104
+ rich==13.9.4
105
+ roman-numerals-py==3.1.0
106
+ rsa==4.7.2
107
+ s3transfer==0.11.2
108
+ safetensors==0.5.2
109
+ SAGEDbias @ https://github.com/holistic-ai/SAGED-Bias/archive/8d5664387c58d94ffd10667c40493a5e460eaac6.zip
110
+ scholarly==1.7.11
111
+ scikit-learn==1.6.1
112
+ scipy==1.13.1
113
+ selenium==4.29.0
114
+ sentence-transformers==3.4.1
115
+ shellingham==1.5.4
116
+ six==1.17.0
117
+ smart-open==7.1.0
118
+ sniffio==1.3.0
119
+ snowballstemmer==2.2.0
120
+ sortedcontainers==2.4.0
121
+ soupsieve==2.6
122
+ spacy==3.8.7
123
+ spacy-legacy==3.0.12
124
+ spacy-loggers==1.0.5
125
+ Sphinx==7.2.6
126
+ sphinx-rtd-theme==3.0.2
127
+ sphinxcontrib-applehelp==2.0.0
128
+ sphinxcontrib-devhelp==2.0.0
129
+ sphinxcontrib-htmlhelp==2.1.0
130
+ sphinxcontrib-jquery==4.1
131
+ sphinxcontrib-jsmath==1.0.1
132
+ sphinxcontrib-qthelp==2.0.0
133
+ sphinxcontrib-serializinghtml==2.0.0
134
+ SQLAlchemy==2.0.25
135
+ srsly==2.5.1
136
+ starlette==0.35.1
137
+ sympy==1.13.1
138
+ tenacity==8.2.3
139
+ thinc==8.3.4
140
+ threadpoolctl==3.5.0
141
+ tiktoken==0.5.2
142
+ tokenizers==0.22.0
143
+ tqdm==4.67.1
144
+ transformers==4.56.1
145
+ trio==0.29.0
146
+ trio-websocket==0.12.2
147
+ typer==0.15.1
148
+ types-python-dateutil==2.9.0.20241206
149
+ typing-inspect==0.9.0
150
+ typing_extensions==4.12.2
151
+ tzdata==2025.1
152
+ urllib3==2.1.0
153
+ uvicorn==0.26.0
154
+ wasabi==1.1.3
155
+ weasel==0.4.1
156
+ websocket-client==1.8.0
157
+ Werkzeug==2.0.3
158
+ Wikipedia-API==0.7.3
159
+ wrapt==1.17.2
160
+ wsproto==1.2.0
161
+ xgboost==3.0.0
162
+ xxhash==3.5.0
163
+ yarl==1.9.4
164
+ seaborn
backend/server.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Select headless backends BEFORE any import below can pull in matplotlib or Qt:
# matplotlib reads MPLBACKEND when it is first imported, so setting it after
# `utils.utils` (which presumably imports matplotlib for plotting) is too late.
os.environ["MPLBACKEND"] = "Agg"
os.environ["QT_QPA_PLATFORM"] = "offscreen"

from flask import Flask, request, jsonify, send_file, send_from_directory
from flask_cors import CORS
import pandas as pd
import torch
from datetime import datetime
from tqdm import tqdm
import logging
from functools import lru_cache
from typing import Optional, List, Dict, Any
from utils.utils import _ensure_plot_saved

logging.basicConfig(level=logging.INFO)

from utils.sampling import rank_sample
try:
    from transformers import TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments
    print("✓ transformers training components imported")
except Exception as e:
    # Training support is optional: the server still starts without it.
    print(f"✗ transformers training import failed: {e}")
    def finetune(*args, **kwargs):
        """Placeholder used when the transformers training classes cannot be imported."""
        print("Warning: Transformers training components not available, skipping fine-tuning")
        return None

# 🤗 datasets (required: the import error is re-raised)
try:
    from datasets import (
        load_dataset,
        load_dataset_builder,
        get_dataset_config_names,
        get_dataset_split_names,
        Features,
    )
    print("✓ datasets imported")
except Exception as e:
    print(f"✗ datasets import failed: {e}")
    raise

from utils.utils import (
    generate_topk_samples,
    evaluate_generated_outputs,
    load_model_and_tokenizer,
    generate_counterfactual_augmentations,
)
print("✓ utils imported")

app = Flask(__name__)
CORS(app)

# Process-local state shared by the route handlers below.
_MODELS = {}                # model name -> (tokenizer, model, device)
_CURRENT_DATASET = None     # DataFrame sampled by the most recent /pipeline run
_GENERATION_RESULTS = None  # evaluated generations of the most recent /pipeline run
56
+
57
@app.route('/data/<path:filename>')
def serve_data(filename):
    """Serve a file from the local ``data`` directory.

    ``filename`` comes straight from the URL, so the resolved target is checked
    to be inside the data directory before anything is opened; requests that
    escape it (``..`` segments, absolute paths, symlinks pointing outside) are
    answered with 404, as are directories and missing files.

    Returns:
        The file bytes with a MIME type chosen from the file extension, or a
        ``(message, status)`` tuple with 404 / 500 on failure.
    """
    from flask import Response

    print(f"[Static] Requested file: {filename}")

    data_dir = os.path.realpath('data')
    file_path = os.path.realpath(os.path.join(data_dir, filename))

    # Containment check on the fully resolved paths.
    try:
        inside_data_dir = os.path.commonpath([data_dir, file_path]) == data_dir
    except ValueError:
        # commonpath raises when the paths cannot be compared (e.g. different drives).
        inside_data_dir = False
    if not inside_data_dir:
        print(f"[Static] Rejected path outside data directory: {file_path}")
        return "File not found", 404

    is_file = os.path.isfile(file_path)
    print(f"[Static] Full path: {file_path}")
    print(f"[Static] File exists: {is_file}")

    if not is_file:
        return "File not found", 404

    try:
        with open(file_path, 'rb') as f:
            file_data = f.read()

        if filename.endswith('.png'):
            mimetype = 'image/png'
        elif filename.endswith('.jpg') or filename.endswith('.jpeg'):
            mimetype = 'image/jpeg'
        elif filename.endswith('.csv'):
            mimetype = 'text/csv'
        else:
            mimetype = 'application/octet-stream'

        print(f"[Static] Serving {len(file_data)} bytes as {mimetype}")

        return Response(file_data, mimetype=mimetype)

    except Exception as e:
        print(f"[Static] Error reading file: {e}")
        return f"Error reading file: {str(e)}", 500
93
+
94
@app.route('/debug/files', methods=['GET'])
def debug_files():
    """List the entries of the local ``data`` directory as JSON.

    Each entry reports its name, joined path, whether it exists and its size
    (0 when it does not exist). Any failure is reported as a JSON ``error``
    object rather than an HTTP error status.
    """
    def describe(entry_name, root):
        # One JSON record per directory entry.
        full_path = os.path.join(root, entry_name)
        return {
            "name": entry_name,
            "path": full_path,
            "exists": os.path.exists(full_path),
            "size": os.path.getsize(full_path) if os.path.exists(full_path) else 0
        }

    try:
        root = os.path.abspath('data')
        if not os.path.exists(root):
            return jsonify({"error": "Data directory not found", "path": root})

        listing = [describe(entry_name, root) for entry_name in os.listdir(root)]
        return jsonify({
            "data_directory": root,
            "files": listing
        })
    except Exception as exc:
        return jsonify({"error": str(exc)})
117
+
118
def get_model(model_name: str):
    """Return ``(tokenizer, model, device)`` for ``model_name``.

    The triple is loaded with ``load_model_and_tokenizer`` on first use and
    kept in the module-level ``_MODELS`` dict for later calls.
    """
    try:
        entry = _MODELS[model_name]
    except KeyError:
        print(f"Loading new model: {model_name}")
        tokenizer, model, device = load_model_and_tokenizer(model_name)
        entry = (tokenizer, model, device)
        _MODELS[model_name] = entry
    else:
        print(f"Using cached model: {model_name}")
    return entry
126
+
127
+
128
@app.route('/health', methods=['GET'])
def health_check():
    """Report liveness plus a summary of the in-memory state of this process."""
    payload = {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        # Names of the models currently held in the _MODELS cache.
        "loaded_models": list(_MODELS.keys()),
        "dataset_loaded": _CURRENT_DATASET is not None,
        "generation_results_available": _GENERATION_RESULTS is not None
    }
    return jsonify(payload)
137
+
138
+
139
def _flatten_features(feats, prefix: str = "") -> List[str]:
    """Flatten a (possibly nested) feature mapping into dotted column names.

    Nested ``Features``/``dict`` values are expanded recursively, with child
    names joined to their parent by ``.``. If ``feats`` has no usable
    ``items()``, its ``keys()`` are returned instead, or an empty list when
    that fails as well.
    """
    try:
        pairs = feats.items()
    except Exception:
        try:
            return list(feats.keys())
        except Exception:
            return []

    flat: List[str] = []
    for key, child in pairs:
        path = f"{prefix}.{key}" if prefix else key
        try:
            if isinstance(child, (Features, dict)):
                flat += _flatten_features(child, prefix=path)
            else:
                flat.append(path)
        except Exception:
            # Keep the column even if inspecting its type failed.
            flat.append(path)
    return flat
158
+
159
@lru_cache(maxsize=256)
def _get_dataset_fields_cached(dataset_id: str, config: Optional[str], split: str) -> List[str]:
    """Return the sorted, de-duplicated field names of a Hugging Face dataset.

    The dataset builder's declared features are tried first. When the builder
    fails, or succeeds but declares no features, the first row of a streaming
    load of ``split`` is inspected instead. Results are memoized per
    ``(dataset_id, config, split)``; raised exceptions are not cached.

    Raises:
        RuntimeError: when both the builder lookup and the streaming fallback
            fail; the message carries both underlying errors.
    """
    builder_error = None
    try:
        builder = load_dataset_builder(dataset_id, name=config)
        fields = _flatten_features(builder.info.features)
        if fields:
            return sorted(set(fields))
        # No declared features: fall through to the streaming probe below.
    except Exception as exc:
        # `exc` is unbound once the except block ends, so keep a reference.
        builder_error = exc

    try:
        ds = load_dataset(dataset_id, name=config, split=split, streaming=True)
        first = next(iter(ds.take(1)), None)
        if first is None:
            return []
        return sorted(set(first.keys()))
    except Exception as e_stream:
        if builder_error is None:
            # The builder worked but had no metadata; keep the previous
            # behavior (empty list) instead of turning it into an error.
            return []
        raise RuntimeError(f"builder_error={builder_error}; streaming_error={e_stream}")
176
+
177
@app.route('/dataset/fields', methods=['GET'])
def dataset_fields():
    """List the fields of the dataset named by the ``id`` query parameter.

    Optional query parameters: ``config`` and ``split`` (default ``train``).
    Responds with 400 when ``id`` is missing or the lookup fails.
    """
    params = request.args
    dataset_id = params.get('id')
    cfg = params.get('config')
    split = params.get('split', 'train')

    if not dataset_id:
        return jsonify({"error": "Missing required query param 'id'"}), 400

    # Request parameters echoed back in both the success and the error body.
    echoed = {"datasetId": dataset_id, "config": cfg, "split": split}
    try:
        fields = _get_dataset_fields_cached(dataset_id, cfg, split)
        source = "huggingface-builder" if fields else "unknown"
        return jsonify({"fields": fields, **echoed, "source": source})
    except Exception as exc:
        failure = {"error": "Failed to fetch dataset fields", **echoed, "detail": str(exc)}
        return jsonify(failure), 400
201
+
202
@app.route('/dataset/meta', methods=['GET'])
def dataset_meta():
    """Return the config names and split names of the dataset given by ``id``.

    Lookup failures are logged as warnings and produce empty lists; only a
    missing ``id`` yields an error response (400). Splits are resolved for the
    first config when configs exist, otherwise for the dataset without a config.
    """
    dataset_id = request.args.get('id')
    if not dataset_id:
        return jsonify({"error": "Missing required query param 'id'"}), 400

    try:
        configs = get_dataset_config_names(dataset_id)
    except Exception as exc:
        configs = []
        logging.warning(f"get_dataset_config_names failed for {dataset_id}: {exc}")

    splits: List[str] = []
    try:
        if configs:
            builder_kwargs = {"name": configs[0]}
            split_args = (dataset_id, configs[0])
        else:
            builder_kwargs = {}
            split_args = (dataset_id,)
        try:
            # Preferred source: split metadata declared on the builder.
            builder = load_dataset_builder(dataset_id, **builder_kwargs)
            splits = sorted(list(builder.info.splits) or [])
        except Exception:
            # Fallback when the builder is unavailable or has no split info.
            splits = get_dataset_split_names(*split_args)
    except Exception as exc:
        logging.warning(f"get splits failed for {dataset_id}: {exc}")
        splits = []

    return jsonify({
        "datasetId": dataset_id,
        "configs": configs,
        "splits": splits
    })
234
+
235
@app.route('/dataset/field-stats', methods=['GET'])
def dataset_field_stats():
    """Count the values of ``field`` (optionally broken down by ``subfield``).

    The dataset is streamed and at most the first 50000 rows are inspected.
    Rows where ``field`` (or ``subfield``, when given) is ``None`` are skipped.
    Without ``subfield`` the result maps value -> count; with it, the result
    maps value -> {subfield value -> count}. Responds with 400 when ``id`` or
    ``field`` is missing and 500 when loading or counting fails.
    """
    params = request.args
    dataset_id = params.get('id')
    cfg = params.get('config')
    split = params.get('split', 'train')
    field = params.get('field')
    subfield = params.get('subfield')

    if not dataset_id or not field:
        return jsonify({"error": "Missing required query params 'id' or 'field'"}), 400

    try:
        ds = load_dataset(dataset_id, name=cfg, split=split, streaming=True)
        max_rows = 50000
        counter: Dict[str, Any] = {}
        print(f"[field-stats] Computing stats for '{field}'" + (f" → '{subfield}'" if subfield else ""))

        for row_no, record in enumerate(ds):
            if row_no >= max_rows:
                break
            primary = record.get(field)
            if primary is None:
                continue
            if not subfield:
                counter[primary] = counter.get(primary, 0) + 1
                continue
            secondary = record.get(subfield)
            if secondary is None:
                continue
            bucket = counter.setdefault(primary, {})
            bucket[secondary] = bucket.get(secondary, 0) + 1

        return jsonify({
            "field": field,
            "subfield": subfield,
            "datasetId": dataset_id,
            "config": cfg,
            "split": split,
            "counts": counter
        })
    except Exception as exc:
        return jsonify({
            "error": f"Failed to compute field stats: {str(exc)}",
            "datasetId": dataset_id,
            "config": cfg,
            "split": split,
            "field": field,
            "subfield": subfield
        }), 500
280
+
281
+ def _parse_selected_groups_from_config(config: dict) -> List[str]:
282
+ raw = config.get('selectedCfFields', []) or []
283
+ out: List[str] = []
284
+ for s in raw:
285
+ s = (s or "").strip()
286
+ if not s:
287
+ continue
288
+ if "/" in s:
289
+ out.append(s.split("/")[-1])
290
+ else:
291
+ out.append(s)
292
+ seen = set()
293
+ uniq = []
294
+ for x in out:
295
+ if x not in seen:
296
+ uniq.append(x)
297
+ seen.add(x)
298
+ return uniq
299
+
300
def stratified_sample_by_category(df: pd.DataFrame, category_col: str, groups: List[str], total_n: Optional[int]) -> pd.DataFrame:
    """Draw up to ``total_n`` distinct rows, spread evenly over ``groups``.

    Args:
        df: Source rows.
        category_col: Column holding each row's group label.
        groups: Group labels to balance over; labels absent from ``df`` are ignored.
        total_n: Number of rows wanted. ``None`` or a non-positive value
            returns ``df`` unchanged.

    Returns:
        A DataFrame with a fresh 0..k-1 index containing at most ``total_n``
        rows and never the same source row twice. Each present group first
        contributes ``total_n // len(groups)`` rows (or all it has); remaining
        slots are filled one row at a time, cycling over the groups, and only
        then from rows outside the requested groups. If none of ``groups``
        occurs in ``df``, a plain random sample of ``df`` is returned.
    """
    if total_n is None or total_n <= 0:
        return df

    present = set(df[category_col].unique())
    # De-duplicate while keeping the caller's order.
    groups_present = [g for g in dict.fromkeys(groups) if g in present]
    if not groups_present:
        return df.sample(n=min(total_n, len(df)), random_state=42)

    # Work on positional labels so a non-unique or non-default index on `df`
    # cannot cause the wrong rows to be picked or excluded.
    work = df.reset_index(drop=True)
    base_each = total_n // len(groups_present)

    chosen: List[int] = []
    spare = {}
    for g in groups_present:
        shuffled = work[work[category_col] == g].sample(frac=1, random_state=42).index.tolist()
        chosen.extend(shuffled[:base_each])
        spare[g] = shuffled[base_each:]

    # Fill the remaining slots round-robin from rows not taken yet, so the
    # result never exceeds total_n and never repeats a row.
    shortfall = total_n - len(chosen)
    while shortfall > 0:
        took_any = False
        for g in groups_present:
            if shortfall == 0:
                break
            if spare[g]:
                chosen.append(spare[g].pop())
                shortfall -= 1
                took_any = True
        if not took_any:
            break

    # Requested groups exhausted: top up from whatever other rows remain.
    if shortfall > 0:
        pool = work.drop(index=chosen)
        if len(pool) > 0:
            chosen.extend(pool.sample(n=min(shortfall, len(pool)), random_state=777).index.tolist())

    return work.loc[chosen].reset_index(drop=True)
334
+
335
+ def _pairwise_max_abs_diff(means: Dict[str, float]) -> float:
336
+ from itertools import combinations
337
+ keys = list(means.keys())
338
+ if len(keys) < 2:
339
+ return 0.0
340
+ diffs = [abs(means[a] - means[b]) for a, b in combinations(keys, 2)]
341
+ return float(max(diffs)) if diffs else 0.0
342
+
343
+ def _mean_by_cat(df: pd.DataFrame, cats: List[str], score_col: str = "sentiment_score") -> Dict[str, float]:
344
+ out: Dict[str, float] = {}
345
+ for c in cats:
346
+ sub = df[df["category"] == c]
347
+ if len(sub) > 0:
348
+ out[c] = float(sub[score_col].mean())
349
+ return out
350
+
351
@app.route('/pipeline', methods=['POST'])
def run_pipeline():
    """Run the complete pipeline with frontend JobConfig format.

    The request body is either the config itself or ``{"config": {...}}``.
    Steps: load and sample the dataset, load the model, generate and evaluate
    top-k samples, add counterfactual augmentations, rank-sample both the
    original and the augmented results, plot the score distributions and,
    when ``enableFineTuning`` is set, fine-tune on the augmented CSV.
    Intermediate CSVs and plots are written under ``data/``.

    Returns:
        JSON with ``status`` "success" and a ``results`` summary; 400 when a
        numeric config value cannot be parsed; 500 with ``status`` "error"
        when any pipeline step fails.
    """
    global _CURRENT_DATASET, _GENERATION_RESULTS

    data = request.get_json() or {}
    config = data.get('config', data) or {}
    print("[DEBUG] Received config:", config)

    # Parse the numeric options up front so a malformed value is reported as a
    # JSON 400 instead of escaping as an unhandled exception.
    try:
        dataset_id = config.get('dataset') or "AmazonScience/bold"
        model_name = config.get('languageModel', 'openai-community/gpt2')
        top_k = int(config.get('k', 5))
        dataset_limit_raw = config.get('datasetLimit')
        dataset_limit = int(dataset_limit_raw) if dataset_limit_raw is not None else None
        num_cf_per_row = int(config.get('numCounterfactuals') or 3)
        tau = float(config.get('tau', 0.1))
        iterations = int(config.get('iterations', 1000))
        metric_target = config.get('metrictarget')
    except (TypeError, ValueError) as e:
        return jsonify({
            "status": "error",
            "message": f"Invalid pipeline config: {str(e)}"
        }), 400

    try:
        results = {}

        print("Pipeline Step 1: Loading data...")
        ds = load_dataset(dataset_id, split="train")
        # NOTE(review): these column names are hard-coded, so only datasets
        # exposing all five of them can be used here.
        df_full = pd.DataFrame(ds)[["domain", "name", "category", "prompts", "wikipedia"]].copy()

        selected_groups = _parse_selected_groups_from_config(config)
        present_all = sorted(df_full["category"].dropna().unique().tolist())

        if selected_groups:
            selected_groups = [g for g in selected_groups if g in present_all]
            if len(selected_groups) < 2:
                print(f"[Filter] Requested groups not enough in dataset (have {selected_groups}); fallback to ALL categories")
                selected_groups = []
        else:
            print("[Filter] No groups requested from frontend; will use categories present after generation.")

        df_pool = df_full[df_full["category"].isin(selected_groups)].copy() if selected_groups else df_full.copy()

        df = stratified_sample_by_category(
            df=df_pool,
            category_col="category",
            groups=selected_groups if selected_groups else sorted(df_pool["category"].unique().tolist()),
            total_n=dataset_limit
        )

        print(f"[Pool] pool_size={len(df_pool)}, sampled={len(df)}")
        print(f"[Pool] categories in pool: {sorted(df_pool['category'].unique().tolist())}")
        print(f"[Pool] categories in sample: {sorted(df['category'].unique().tolist())}")

        _CURRENT_DATASET = df
        results['data_loaded'] = len(df)
        print(f"Dataset loaded: {len(df)} rows")

        print("Pipeline Step 2: Loading model...")
        tokenizer, model, device = get_model(model_name)
        results['model_loaded'] = model_name

        print(f"Pipeline Step 3: Generating samples for {len(df)} entries...")
        generation_results = generate_topk_samples(model, _CURRENT_DATASET, tokenizer, device, top_k=top_k)
        task = config.get('classificationTask', 'sentiment')
        tox_choice = config.get('toxicityModelChoice', 'detoxify')

        evaluated_results = evaluate_generated_outputs(
            generation_results, device,
            task=task,
            toxicity_model_choice=tox_choice
        )
        _GENERATION_RESULTS = evaluated_results

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs("data", exist_ok=True)
        output_file = f"data/pipeline_generation_{timestamp}.csv"
        evaluated_results.to_csv(output_file, index=False)
        results['generation_file'] = output_file
        results['generation_samples'] = len(evaluated_results)

        print("Pipeline Step 3.5: Counterfactual augmentation...")
        augmented_results = generate_counterfactual_augmentations(
            evaluated_results,
            text_col="generated",
            name_col="name",
            category_col="category",
            num_cf_per_row=num_cf_per_row
        )
        augmented_file = f"data/pipeline_generation_cf_augmented_{timestamp}.csv"
        augmented_results.to_csv(augmented_file, index=False)
        results['counterfactual_file'] = augmented_file
        results['counterfactual_added'] = len(augmented_results) - len(evaluated_results)
        results['counterfactual_total'] = len(augmented_results)

        present_after_gen = sorted(evaluated_results["category"].dropna().unique().tolist())
        if not selected_groups:
            selected_groups_used = present_after_gen
        else:
            selected_groups_used = [g for g in selected_groups if g in present_after_gen]
            if len(selected_groups_used) < 2:
                print(f"[Sampling] After generation only {selected_groups_used} present; expanding to all present categories")
                selected_groups_used = present_after_gen

        print(f"[Sampling] Using groups: {selected_groups_used}")

        print("Debug: Checking data before sampling...")
        print(f"Total evaluated results: {len(evaluated_results)}")
        print(f"Categories in data: {present_after_gen}")
        print(f"Names in data: {evaluated_results['name'].unique()}")

        for cat in selected_groups_used:
            cat_count = int((evaluated_results["category"] == cat).sum())
            print(f"Category '{cat}': {cat_count} samples")

        # Resolve the score column once, before sampling, so sampling, the
        # per-group means and the plots all use the column the evaluator
        # actually produced instead of assuming "sentiment_score".
        score_col = None
        for c in [
            "sentiment_score", "regard_score", "toxicity_score",
            "stereotype_gender_score", "stereotype_religion_score",
            "stereotype_profession_score", "stereotype_race_score",
            "personality_score",
        ]:
            if c in evaluated_results.columns:
                score_col = c
                break
        if score_col is None:
            raise KeyError(f"No score column found. Available: {list(evaluated_results.columns)}")

        sampling_kwargs = {"num_samples": iterations, "temp": tau, "sentiment_col": score_col}
        if metric_target is not None:
            # Without an explicit target, rank_sample's own default applies.
            sampling_kwargs["target_value"] = metric_target

        print(f"Pipeline Step 4: Rank sampling on original evaluated results...(iterations={iterations}, temp={tau})")
        try:
            best_sent_subset = rank_sample(evaluated_results, **sampling_kwargs)
        except (ValueError, IndexError) as e:
            # Fallback: keep the first half of the rows when sampling fails.
            print(f"Sampling failed: {e}")
            mid_point = len(evaluated_results) // 2
            best_sent_subset = evaluated_results.iloc[:mid_point].copy()

        sent_file = f"data/pipeline_sent_subset_{timestamp}.csv"
        best_sent_subset.to_csv(sent_file, index=False)

        print(f"Pipeline Step 5: Rank sampling on CF-augmented results...(iterations={iterations}, temp={tau})")
        try:
            cf_best_sent_subset = rank_sample(augmented_results, **sampling_kwargs)
        except (ValueError, IndexError) as e:
            print(f"CF Sampling failed: {e}")
            mid_point = len(augmented_results) // 2
            cf_best_sent_subset = augmented_results.iloc[:mid_point].copy()

        cf_sent_file = f"data/pipeline_cf_sent_subset_{timestamp}.csv"
        cf_best_sent_subset.to_csv(cf_sent_file, index=False)

        orig_means = _mean_by_cat(best_sent_subset, selected_groups_used, score_col=score_col)
        final_mean_diff = _pairwise_max_abs_diff(orig_means)

        cf_means = _mean_by_cat(cf_best_sent_subset, selected_groups_used, score_col=score_col)
        cf_final_mean_diff = _pairwise_max_abs_diff(cf_means)

        print("Pipeline Step 6: Plotting distributions...")

        def _safe(s: str) -> str:
            # Replace every run of characters outside [A-Za-z0-9_.-] with "_".
            import re
            return re.sub(r"[^A-Za-z0-9_.-]+", "_", s)

        orig_sent_title = _safe(f"{timestamp}_original_distribution")
        cf_sent_title = _safe(f"{timestamp}_cf_distribution")

        orig_path = _ensure_plot_saved(
            best_sent_subset, score_col, orig_sent_title,
            group_col="category", target=metric_target
        )
        cf_path = _ensure_plot_saved(
            cf_best_sent_subset, score_col, cf_sent_title,
            group_col="category", target=metric_target
        )
        print("[Plot check exists]", orig_path, os.path.exists(orig_path))
        print("[Plot check exists]", cf_path, os.path.exists(cf_path))

        # NOTE(review): the URLs are built from the titles, not from the paths
        # returned above - presumably _ensure_plot_saved writes data/<title>.png.
        results['plots'] = {
            'original_sentiment': f"/data/{orig_sent_title}.png",
            'counterfactual_sentiment': f"/data/{cf_sent_title}.png",
        }

        print("[Plot urls]", results['plots'])

        if config.get("enableFineTuning"):
            print("Pipeline Step 7: Fine-tuning enabled, starting training...")

            ft_cfg = config.get("finetuneParams", {}) or {}
            epochs = int(ft_cfg.get("epochs", 3))
            batch_size = int(ft_cfg.get("batchSize", 8))
            lr = float(ft_cfg.get("learningRate", 5e-5))

            input_csv = augmented_file
            ft_output_dir = f"data/ft_{timestamp}"
            os.makedirs(ft_output_dir, exist_ok=True)

            # A fine-tuning failure is reported in the results, not as a 500.
            try:
                from utils.finetune import finetune_gpt2_from_csv
                # NOTE(review): no base_model is passed, so the helper's default
                # is fine-tuned regardless of `languageModel` - confirm intent.
                finetune_gpt2_from_csv(
                    csv_path=input_csv,
                    output_dir=ft_output_dir,
                    epochs=epochs,
                    batch_size=batch_size,
                    lr=lr
                )
                print(f"[Fine-tune] Saved fine-tuned model to {ft_output_dir}")
                results["finetuned_model_dir"] = ft_output_dir
                zip_base = f"data/ft_{timestamp}"
                import shutil
                zip_path = shutil.make_archive(zip_base, 'zip', ft_output_dir)
                results["finetuned_model_zip"] = f"/data/{os.path.basename(zip_path)}"
            except Exception as fe:
                print(f"[Fine-tune] Failed: {fe}")
                results["finetuned_model_error"] = str(fe)

        results.update({
            'sampling_method': 'rank_sentiment_only',
            'used_groups': selected_groups_used,
            'sentiment_subset_file': sent_file,
            'cf_sentiment_subset_file': cf_sent_file,
            'sentiment_subset_size': len(best_sent_subset),
            'cf_sentiment_subset_size': len(cf_best_sent_subset),
            'config_used': config,
            'metrics': {
                'finalMeanDiff': final_mean_diff,
                'cfFinalMeanDiff': cf_final_mean_diff,
                # Relative reduction of the group gap, clamped at 0 when CF made it worse.
                'reductionPct': (0.0 if final_mean_diff == 0 else max(0.0, (final_mean_diff - cf_final_mean_diff) / abs(final_mean_diff) * 100.0)),
                'stableCoverage': 100.0
            }
        })

        return jsonify({
            "status": "success",
            "message": "Complete pipeline executed successfully (with counterfactual augmentation)",
            "results": results,
            "timestamp": timestamp
        })

    except Exception as e:
        print(f"Error in pipeline: {str(e)}")
        return jsonify({
            "status": "error",
            "message": f"Pipeline failed: {str(e)}"
        }), 500
589
+
590
+
591
if __name__ == '__main__':
    # Direct-run entry point for local development.
    os.makedirs("data", exist_ok=True)
    print("Starting minimal Flask server...")
    print("Available endpoints:")
    print("  GET  /health - Health check")
    print("  GET  /dataset/fields?id=<hf_id>[&config=...][&split=...] - List dataset fields")
    print("  GET  /dataset/field-stats?id=...&field=... - Get value distribution of a field")
    print("  GET  /dataset/meta?id=<hf_id> - List configs/splits")
    print("  POST /pipeline - Run complete pipeline")
    # The server binds to all interfaces, so the interactive debugger (which
    # can execute arbitrary code) must be opted into explicitly via FLASK_DEBUG
    # rather than being always on.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host='0.0.0.0', port=5001, debug=debug_enabled, threaded=True)
backend/utils/finetune.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, math, random
2
+ import pandas as pd
3
+ import torch
4
+ from typing import Optional
5
+ from transformers import (AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling,
6
+ Trainer, TrainingArguments)
7
+
8
+ try:
9
+ from peft import LoraConfig, get_peft_model, TaskType
10
+ PEFT_AVAILABLE = True
11
+ except Exception:
12
+ PEFT_AVAILABLE = False
13
+
14
def build_text_column(df: pd.DataFrame) -> pd.Series:
    """Build the training-text Series from a DataFrame.

    Column names are matched case-insensitively, in this order of preference:
      1. ``text``                 -> used as is;
      2. ``prompt`` + ``generated`` -> combined into an instruction/response template;
      3. ``generated``            -> used as is.

    Missing values become empty strings before the string conversion, so they
    do not end up as literal "nan"/"None" text.

    Raises:
        ValueError: when none of the supported column combinations is present.
    """
    lower_map = {str(c).lower(): c for c in df.columns}

    def as_text(column) -> pd.Series:
        # Fill missing values first: astype(str) would stringify them instead.
        return df[column].fillna("").astype(str)

    if 'text' in lower_map:
        return as_text(lower_map['text'])

    if 'prompt' in lower_map and 'generated' in lower_map:
        prompts = as_text(lower_map['prompt'])
        responses = as_text(lower_map['generated'])
        return "### Instruction:\n" + prompts + "\n\n### Response:\n" + responses + "\n"

    if 'generated' in lower_map:
        return as_text(lower_map['generated'])

    raise ValueError("CSV 缺少可用欄位:請提供 text,或 prompt+generated,或 generated。")
27
+
28
def finetune_gpt2_from_csv(
    csv_path: str,
    base_model: str = "gpt2",
    output_dir: str = "data/ft_gpt2_out",
    train_split: float = 0.9,
    epochs: int = 3,
    lr: float = 5e-5,
    batch_size: int = 2,
    use_lora: bool = False,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    seed: int = 42,
    max_length: int = 512,
) -> dict:
    """Fine-tune a causal language model on the texts found in a CSV file.

    Args:
        csv_path: CSV with a ``text`` column, ``prompt``+``generated`` columns,
            or a ``generated`` column (see ``build_text_column``).
        base_model: Model id or path given to ``from_pretrained``.
        output_dir: Where checkpoints, the final model and tokenizer are saved.
        train_split: Fraction of rows (in file order) used for training; the
            rest is used for evaluation.
        epochs, lr, batch_size: Training hyper-parameters.
        use_lora, lora_r, lora_alpha, lora_dropout: LoRA settings; ignored with
            a message when ``peft`` is not installed.
        seed: Seed for ``random`` and ``torch``.
        max_length: Token truncation length.

    Returns:
        dict with ``output_dir``, ``train_size``, ``eval_size`` and
        ``perplexity`` (``None`` when no eval loss was reported).

    Raises:
        ValueError: when the CSV yields no rows to train on.
    """
    os.makedirs(output_dir, exist_ok=True)
    random.seed(seed); torch.manual_seed(seed)

    df = pd.read_csv(csv_path)
    texts = build_text_column(df).fillna("").tolist()
    if not texts:
        raise ValueError(f"No training rows found in {csv_path}")

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    if tokenizer.pad_token is None:
        # Batches need a pad token; reuse EOS when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(base_model)

    if use_lora:
        if not PEFT_AVAILABLE:
            print("PEFT 未安裝,改為全參數微調")
        else:
            lconf = LoraConfig(
                r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                task_type=TaskType.CAUSAL_LM, target_modules=["c_attn","c_proj","q_attn"]  # depends on the model architecture
            )
            model = get_peft_model(model, lconf)

    def tokenize(example_texts):
        return tokenizer(example_texts, truncation=True, max_length=max_length)

    # Keep at least one training example even for very small inputs.
    split_idx = min(len(texts), max(1, int(len(texts) * train_split)))
    train_texts = texts[:split_idx]
    # With nothing left for evaluation, reuse a small slice of the data.
    val_texts = texts[split_idx:] or texts[: max(1, len(texts)//10)]

    train_enc = tokenize(train_texts)
    val_enc = tokenize(val_texts)

    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    class SimpleDS(torch.utils.data.Dataset):
        """Wraps a tokenizer encoding as a map-style dataset of tensors."""
        def __init__(self, enc): self.enc = enc
        def __len__(self): return len(self.enc["input_ids"])
        def __getitem__(self, idx):
            return {k: torch.tensor(v[idx]) for k, v in self.enc.items()}

    train_ds, val_ds = SimpleDS(train_enc), SimpleDS(val_enc)

    # TrainingArguments rejects fp16 and bf16 both being enabled, and the bf16
    # capability query is only meaningful when a CUDA device is present.
    cuda_available = torch.cuda.is_available()
    use_bf16 = bool(
        cuda_available
        and hasattr(torch.cuda, "is_bf16_supported")
        and torch.cuda.is_bf16_supported()
    )
    use_fp16 = cuda_available and not use_bf16

    args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        learning_rate=lr,
        warmup_ratio=0.03,
        weight_decay=0.01,
        logging_steps=20,
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=2,
        fp16=use_fp16,
        bf16=use_bf16,
        report_to=[],
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=collator,
    )

    trainer.train()
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Evaluate once and reuse the metrics (each call runs a full eval pass).
    eval_loss = trainer.evaluate().get("eval_loss")
    if eval_loss is None:
        perplexity = None
    else:
        try:
            perplexity = math.exp(eval_loss)
        except OverflowError:
            perplexity = float("inf")

    return {
        "output_dir": output_dir,
        "train_size": len(train_ds),
        "eval_size": len(val_ds),
        "perplexity": perplexity
    }
backend/utils/sampling.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from tqdm import tqdm
4
+ from typing import List, Optional
5
+
6
def rank_sample(
    df: pd.DataFrame,
    name_col: str = "name",
    category_col: str = "category",
    sentiment_col: str = "sentiment_score",
    groups: Optional[List[str]] = None,
    num_samples: int = 1000,
    temp: float = 1.0,
    target_value: float = 0.5,
) -> pd.DataFrame:
    """Pick one row per name so that the category means end up as close as possible.

    For every distinct value of ``name_col`` exactly one row is drawn at random.
    Rows whose score is closer to ``target_value`` get a better (lower) rank and
    therefore a higher softmax probability.  The draw is repeated ``num_samples``
    times and the draw with the smallest maximum pairwise difference between
    per-category mean scores is returned.

    Args:
        df: Input rows. The caller's frame is never modified.
        name_col: Column identifying the entity to keep one row for.
        category_col: Column holding the group label of each row.
        sentiment_col: Column holding the numeric score.
        groups: Optional whitelist of categories. It is ignored (with a printed
            warning) when fewer than two of the requested groups are present.
        num_samples: Number of random draws to evaluate.
        temp: Softmax temperature applied to the ranks; values that are falsy or
            not above 1e-8 are clamped to 1e-8.
        target_value: Score that rows are ranked against.

    Returns:
        The best draw, carrying the extra columns ``sentiment_deviation`` and
        ``sentiment_rank`` and keeping the original index labels.  When the data
        has fewer than two categories, or no draw was valid, the first row of
        every name (``groupby(name_col).first().reset_index()``) is returned.

    Raises:
        ValueError: If one of the three columns is missing from ``df``.
    """
    for col in [name_col, category_col, sentiment_col]:
        if col not in df.columns:
            raise ValueError(f"Column '{col}' not found in DataFrame")

    # Work on a private copy so the helper columns added below never leak
    # into the caller's frame.
    df = df.dropna(subset=[name_col, category_col, sentiment_col]).copy()

    if groups:
        available_groups = df[category_col].unique()
        valid_groups = [g for g in groups if g in available_groups]
        if len(valid_groups) < 2:
            print(f"Warning: Only {len(valid_groups)} groups available from {groups}")
            groups = None
        else:
            groups = valid_groups
            df = df[df[category_col].isin(groups)].copy()

    final_groups = df[category_col].unique()
    if len(final_groups) < 2:
        print(f"Error: Only {len(final_groups)} groups in data, need at least 2")
        return df.groupby(name_col).first().reset_index()

    print(f"Sampling with groups: {sorted(final_groups)}")
    print(f"Target value for deviation calculation: {target_value}")

    # Assign raw arrays so that duplicate index labels cannot trip up alignment.
    df["sentiment_deviation"] = (df[sentiment_col] - target_value).abs().to_numpy()
    df["sentiment_rank"] = (
        df.groupby(name_col)["sentiment_deviation"]
        .rank(method="first", ascending=True)
        .to_numpy()
    )

    def softmax_weights(ranks: np.ndarray, temp: float) -> np.ndarray:
        # Lower rank => larger weight; shift by the max for numerical stability.
        t = float(temp) if temp and temp > 1e-8 else 1e-8
        x = -ranks / t
        x = x - np.max(x)
        exps = np.exp(x)
        s = exps.sum()
        return exps / s if np.isfinite(s) and s > 0 else np.ones_like(exps) / len(exps)

    def objective_max_pairwise_diff(frame: pd.DataFrame) -> float:
        # Largest gap between any two per-category means; inf if not comparable.
        g = frame.groupby(category_col)[sentiment_col].mean().dropna()
        if len(g) < 2:
            return np.inf
        vals = g.values
        diffs = np.abs(vals[:, None] - vals[None, :])
        return float(np.max(diffs))

    # The grouping and the softmax weights do not change between draws, so they
    # are computed once.  Row *positions* (not index labels) are stored, which
    # keeps sampling correct even when the index contains duplicate labels.
    row_positions = pd.Series(np.arange(len(df)), index=df.index)
    rank_values = df["sentiment_rank"].to_numpy(dtype=float)
    sampling_plan = []
    for _, member_positions in row_positions.groupby(df[name_col].to_numpy()):
        positions = member_positions.to_numpy()
        if len(positions) == 0:
            continue
        sampling_plan.append((positions, softmax_weights(rank_values[positions], temp=temp)))

    best_subset = None
    best_obj = np.inf
    valid_samples = 0

    unique_names = df[name_col].nunique()
    print(f"Total unique names: {unique_names}")

    for i in tqdm(range(num_samples), desc="Sampling"):
        try:
            chosen = [np.random.choice(positions, p=weights) for positions, weights in sampling_plan]
            if len(chosen) == 0:
                continue

            # iloc keeps the column dtypes and the original index labels.
            subset = df.iloc[chosen]

            subset_groups = subset[category_col].unique()
            if len(subset_groups) < 2:
                continue

            obj = objective_max_pairwise_diff(subset)

            if np.isfinite(obj):
                valid_samples += 1
                if obj < best_obj:
                    best_obj = obj
                    best_subset = subset.copy()

                if valid_samples % 100 == 0 or valid_samples <= 10:
                    group_means = subset.groupby(category_col)[sentiment_col].mean()
                    print(f"Sample {valid_samples}: obj={obj:.4f}, groups={dict(group_means)}")

        except Exception as e:
            # Best effort: a failing draw is reported and skipped, not fatal.
            print(f"Error in sample {i}: {e}")
            continue

    print(f"Valid samples: {valid_samples}/{num_samples}")
    print(f"Best objective: {best_obj:.4f}")

    if best_subset is None or len(best_subset) == 0:
        print("Warning: No valid samples found, returning fallback subset")
        best_subset = df.groupby(name_col).first().reset_index()

    final_group_counts = best_subset[category_col].value_counts()
    print(f"Final subset group distribution: {dict(final_group_counts)}")

    return best_subset
backend/utils/utils.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import TextDataset,DataCollatorForLanguageModeling,Trainer,TrainingArguments
2
+ import torch
3
+ import pandas as pd
4
+ from tqdm import tqdm
5
+ import torch.nn.functional as F
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
7
+ import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
+ import numpy as np
10
+ import os
11
+ import sys
12
+ from transformers import (
13
+ AutoTokenizer,
14
+ AutoModelForCausalLM,
15
+ GPT2LMHeadModel,
16
+ GPT2Tokenizer,
17
+ )
18
+
19
def load_model_and_tokenizer(model_name: str):
    """Load a causal language model and its tokenizer and move the model to a device.

    The device is CUDA when available, otherwise Apple-Silicon MPS when
    available, otherwise CPU.  The three known GPT-2 checkpoints are loaded
    with the explicit GPT-2 classes, everything else with the ``Auto*`` classes.
    If no pad token (tokenizer) or pad token id (model config) is set, the
    end-of-sequence value is reused for it.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model, device)`` tuple.

    Raises:
        RuntimeError: If anything goes wrong while loading.
    """
    device_name = "cpu"
    if torch.cuda.is_available():
        device_name = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device_name = "mps"
    device = torch.device(device_name)

    gpt2_aliases = {"gpt2", "openai-community/gpt2", "holistic-ai/gpt2-EMGSD"}

    try:
        if model_name in gpt2_aliases:
            tokenizer_cls, model_cls = GPT2Tokenizer, GPT2LMHeadModel
        else:
            tokenizer_cls, model_cls = AutoTokenizer, AutoModelForCausalLM

        tokenizer = tokenizer_cls.from_pretrained(model_name)
        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token

        model = model_cls.from_pretrained(model_name)
        pad_id = getattr(model.config, "pad_token_id", None)
        eos_id = getattr(model.config, "eos_token_id", None)
        if pad_id is None and eos_id is not None:
            model.config.pad_token_id = model.config.eos_token_id

        model.to(device)
        return tokenizer, model, device
    except Exception as e:
        raise RuntimeError(f"Failed to load model '{model_name}': {e}")
49
+
50
def finetune(train_texts, tokenizer, model, num_epochs=20, output_dir='./data'):
    """Fine-tune ``model`` as a causal LM on ``train_texts`` and return it.

    The texts are stripped and written one per line to ``data/train.txt``
    (relative to the current working directory), read back as 128-token blocks
    through ``TextDataset`` and trained with the Hugging Face ``Trainer``.

    Args:
        train_texts: Iterable of strings to train on.
        tokenizer: Tokenizer matching ``model``.
        model: Model to train; it is trained in place.
        num_epochs: Number of training epochs.
        output_dir: Directory handed to ``TrainingArguments`` for checkpoints.

    Returns:
        The same ``model`` object after training.
    """
    train_path = "data/train.txt"

    # The corpus path is fixed and independent of ``output_dir``, so make sure
    # its folder exists instead of failing in a fresh working directory.
    os.makedirs(os.path.dirname(train_path), exist_ok=True)

    with open(train_path, "w", encoding="utf-8") as f:
        for text in train_texts:
            f.write(text.strip() + "\n")

    train_dataset = TextDataset(tokenizer=tokenizer, file_path=train_path, block_size=128)
    # mlm=False selects plain causal language-modelling labels.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        per_device_train_batch_size=1,
        num_train_epochs=num_epochs,
        save_steps=500,
        save_total_limit=2,
        logging_dir='./logs',
        logging_steps=10,
        report_to="none"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )

    trainer.train()

    return model
82
+
83
def generate_topk_samples(model, df_table, tokenizer, device, top_k=10):
    """Sample ``top_k`` continuations for every prompt row and return them flat.

    Note that ``df_table["prompts"]`` is overwritten in place: a list-valued
    prompt is replaced by its first element.

    Args:
        model: Generative model exposing ``eval()`` and ``generate(...)``.
        df_table: Frame with the columns ``prompts``, ``domain``, ``name``,
            ``category`` and ``wikipedia``.
        tokenizer: Tokenizer used to encode prompts and decode outputs.
        device: Device the encoded inputs are moved to.
        top_k: Used both as the top-k sampling cutoff and as the number of
            sequences returned per prompt.

    Returns:
        A DataFrame with one row per generated sequence; ``generated`` holds the
        decoded, stripped text and the other columns are copied from the source row.
    """
    model.eval()

    def _first_if_list(value):
        return value[0] if isinstance(value, list) else value

    df_table["prompts"] = df_table["prompts"].apply(_first_if_list)

    records = []
    progress = tqdm(df_table.iterrows(), total=len(df_table), desc="Generating samples")
    for _, source_row in progress:
        prompt = source_row["prompts"]

        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            padding=True
        ).to(device)

        with torch.no_grad():
            sequences = model.generate(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                do_sample=True,
                top_k=top_k,
                max_new_tokens=20,
                top_p=1.0,
                num_return_sequences=top_k,
                pad_token_id=tokenizer.eos_token_id
            )

        records.extend(
            {
                "domain": source_row["domain"],
                "name": source_row["name"],
                "category": source_row["category"],
                "prompts": prompt,
                "wikipedia": source_row["wikipedia"],
                "generated": tokenizer.decode(seq, skip_special_tokens=True).strip(),
            }
            for seq in sequences
        )

    return pd.DataFrame(records)
123
+
124
+
125
def _label_scores(raw):
    """Turn a text-classification pipeline result into ``{lowercased label: score}``.

    Accepts a single ``{"label", "score"}`` dict, a flat list of such dicts, or
    that list wrapped in one more list; anything else yields an empty dict.
    """
    if isinstance(raw, dict):
        raw = [raw]
    elif isinstance(raw, (list, tuple)) and len(raw) == 1 and isinstance(raw[0], (list, tuple)):
        raw = raw[0]
    if not isinstance(raw, (list, tuple)):
        return {}
    scores = {}
    for item in raw:
        if isinstance(item, dict) and "label" in item:
            scores[str(item["label"]).lower()] = float(item.get("score", 0.0))
    return scores


def _classify_texts(clf, texts, desc, fallback):
    """Run ``clf`` on every text and return one label->score dict per text.

    Scoring is best effort: a text whose classification raises, or yields no
    usable labels, gets a copy of ``fallback`` instead of aborting the run.
    """
    results = []
    for text in tqdm(texts, desc=desc):
        try:
            scores = _label_scores(clf(text))
        except Exception:
            scores = {}
        results.append(scores if scores else dict(fallback))
    return results


def _toxic_probability(scores):
    """Read a toxicity probability out of a label->score dict.

    A plain "...toxic..." label wins; if only a negated label such as
    "non-toxic" is present its complement is used; otherwise 0.0.
    """
    negated_prefixes = ("non", "not", "no-", "no_", "no ")
    negated_score = None
    for label, score in scores.items():
        if "toxic" not in label:
            continue
        if label.startswith(negated_prefixes):
            if negated_score is None:
                negated_score = score
            continue
        return float(score)
    if negated_score is not None:
        return float(1.0 - negated_score)
    return 0.0


def evaluate_generated_outputs(
    table: pd.DataFrame,
    device,
    task: str = "sentiment",
    toxicity_model_choice: str = "detoxify",
    text_col: str = "generated",
) -> pd.DataFrame:
    """Score every text in ``table[text_col]`` with the classifier for ``task``.

    Args:
        table: Input frame; it is copied, not modified.
        device: Device for the sentiment model.  The pipelines of the other
            tasks run on GPU 0 when this is a CUDA ``torch.device``, else on CPU.
        task: One of ``sentiment``, ``regard``, ``stereotype``, ``personality``
            or ``toxicity`` (case-insensitive; falsy means ``sentiment``).
        toxicity_model_choice: ``"detoxify"`` or ``"junglelee"``; only used by
            the toxicity task.
        text_col: Column holding the texts; missing values are scored as "".

    Returns:
        A copy of ``table`` with the task-specific score column(s) added and a
        ``sentiment_score`` column that mirrors the task's main score.

    Raises:
        AssertionError: If ``text_col`` is not a column of ``table``.
        ValueError: For an unknown ``task`` or ``toxicity_model_choice``.
    """
    assert text_col in table.columns, f"'{text_col}' not found in table columns"

    pipe_device = 0 if (isinstance(device, torch.device) and device.type == "cuda") else -1

    df = table.copy()
    texts = df[text_col].fillna("").astype(str).tolist()

    task = (task or "sentiment").lower()

    if task == "sentiment":
        print("Using default sentiment classifier: lxyuan/distilbert-base-multilingual-cased-sentiments-student")
        tok = AutoTokenizer.from_pretrained("lxyuan/distilbert-base-multilingual-cased-sentiments-student")
        mdl = AutoModelForSequenceClassification.from_pretrained("lxyuan/distilbert-base-multilingual-cased-sentiments-student").to(device).eval()

        # Look the class positions up in the model config rather than assuming
        # an order; fall back to index 0 = negative, 2 = positive when the
        # config carries no such label names.
        id2label = getattr(getattr(mdl, "config", None), "id2label", None) or {}
        label_index = {str(label).lower(): int(i) for i, label in id2label.items()}
        neg_idx = label_index.get("negative", 0)
        pos_idx = label_index.get("positive", 2)

        scores = []
        for text in tqdm(texts, desc="Scoring (sentiment)"):
            if not text.strip():
                # Empty text counts as neutral.
                scores.append(0.5)
                continue
            inputs = tok(text, return_tensors="pt", truncation=True, padding=True).to(device)
            with torch.no_grad():
                logits = mdl(**inputs).logits
            probs = F.softmax(logits, dim=1).squeeze(0).tolist()
            # Map P(positive) - P(negative) from [-1, 1] onto [0, 1].
            val = (probs[pos_idx] - probs[neg_idx] + 1.0) / 2.0
            scores.append(float(val))

        df["sentiment_score"] = scores
        return df

    elif task == "regard":
        print("Using default regard classifier: sasha/regardv3")
        clf = pipeline("text-classification", model="sasha/regardv3", device=pipe_device, top_k=None)

        per_text = _classify_texts(
            clf, texts, "Scoring (regard)", {"positive": 0.5, "negative": 0.5}
        )
        # NOTE(review): unlike the sentiment task this is not halved, so the
        # score spans [0, 2] with 1.0 as neutral - confirm this is intended.
        df["regard_score"] = [
            float(d.get("positive", 0.5)) - float(d.get("negative", 0.5)) + 1.0
            for d in per_text
        ]
        df["sentiment_score"] = df["regard_score"]
        return df

    elif task == "stereotype":
        print("Using default stereotype classifier: holistic-ai/stereotype-deberta-v3-base-tasksource-nli")
        clf = pipeline("text-classification", model="holistic-ai/stereotype-deberta-v3-base-tasksource-nli", device=pipe_device, top_k=None)

        kinds = ["gender", "religion", "profession", "race"]
        per_text = _classify_texts(
            clf, texts, "Scoring (stereotype)", {f"stereotype_{k}": 0.0 for k in kinds}
        )
        for k in kinds:
            df[f"stereotype_{k}_score"] = [float(d.get(f"stereotype_{k}", 0.0)) for d in per_text]

        df["sentiment_score"] = df["stereotype_gender_score"]
        return df

    elif task == "personality":
        print("Using default personality classifier: Navya1602/editpersonality_classifier")
        clf = pipeline("text-classification", model="Navya1602/editpersonality_classifier", device=pipe_device, top_k=None)

        traits = ["extraversion", "neuroticism", "agreeableness", "conscientiousness", "openness"]

        per_text = _classify_texts(
            clf, texts, "Scoring (personality)", {t: 0.2 for t in traits}
        )
        for t in traits:
            df[f"{t}_score"] = [float(d.get(t, 0.2)) for d in per_text]

        df["sentiment_score"] = df[[f"{t}_score" for t in traits]].mean(axis=1)
        return df

    elif task == "toxicity":
        if toxicity_model_choice == "detoxify":
            print("Using unitary/toxic-bert model for toxicity classification")
            clf = pipeline("text-classification", model="unitary/toxic-bert", device=pipe_device, top_k=None)

            def _get_toxic_prob(scores) -> float:
                return float(scores.get("toxic", scores.get("toxic/overall", 0.0)))
        elif toxicity_model_choice == "junglelee":
            print("Using JungleLee/bert-toxic-comment-classification for toxicity classification")
            clf = pipeline("text-classification", model="JungleLee/bert-toxic-comment-classification", device=pipe_device)
            _get_toxic_prob = _toxic_probability
        else:
            raise ValueError("Invalid toxicity_model_choice. Choose 'detoxify' or 'junglelee'.")

        per_text = _classify_texts(clf, texts, "Scoring (toxicity)", {})
        df["toxicity_score"] = [_get_toxic_prob(d) for d in per_text]
        df["sentiment_score"] = df["toxicity_score"]
        return df

    else:
        raise ValueError(f"Unknown task '{task}'. Use one of: sentiment | regard | stereotype | personality | toxicity")
309
+
310
+
311
+ import numpy as np
312
+ import pandas as pd
313
+ from typing import List, Dict, Optional
314
+
315
+ def _generate_cross_category_cf(base_df, text_col, name_col, category_col, num_cf_per_row):
316
+ categories = base_df[category_col].unique().tolist()
317
+ category_names = {}
318
+
319
+ for cat in categories:
320
+ category_names[cat] = base_df[base_df[category_col] == cat][name_col].unique().tolist()
321
+
322
+ print(f"Categories for CF generation: {[f'{cat}({len(names)})' for cat, names in category_names.items()]}")
323
+
324
+ cf_rows = []
325
+ for idx, row in base_df.iterrows():
326
+ original_text = row[text_col]
327
+ original_name = row[name_col]
328
+ original_category = row[category_col]
329
+ original_name_clean = original_name.replace("_", " ")
330
+
331
+ other_categories = [cat for cat in categories if cat != original_category]
332
+
333
+ for target_category in other_categories:
334
+ target_names = category_names[target_category]
335
+
336
+ if len(target_names) == 0:
337
+ continue
338
+
339
+ num_to_sample = min(num_cf_per_row // len(other_categories) + 1, len(target_names))
340
+ if num_to_sample == 0:
341
+ continue
342
+
343
+ sampled_names = np.random.choice(target_names, size=num_to_sample, replace=False)
344
+
345
+ for new_name in sampled_names:
346
+ new_name_clean = new_name.replace("_", " ")
347
+
348
+ new_text = original_text.replace(original_name_clean, new_name_clean, 1)
349
+
350
+ if new_text == original_text:
351
+ original_parts = original_name_clean.split()
352
+ for part in original_parts:
353
+ if len(part) > 2:
354
+ new_text = original_text.replace(part, new_name_clean, 1)
355
+ if new_text != original_text:
356
+ break
357
+
358
+ if new_text == original_text:
359
+ continue
360
+
361
+ new_row = row.copy()
362
+ new_row[name_col] = new_name
363
+ new_row[text_col] = new_text
364
+ new_row[category_col] = target_category
365
+ new_row["original_category"] = original_category
366
+ new_row["cf_type"] = f"{original_category}->{target_category}"
367
+ cf_rows.append(new_row)
368
+
369
+ counterfactual_df = pd.DataFrame(cf_rows)
370
+
371
+ if len(counterfactual_df) > 0:
372
+ cf_stats = counterfactual_df["cf_type"].value_counts()
373
+ print(f"CF generation stats:")
374
+ for cf_type, count in cf_stats.items():
375
+ print(f" {cf_type}: {count}")
376
+
377
+ augmented_df = pd.concat([base_df, counterfactual_df], ignore_index=True)
378
+
379
+ print(f"\nAugmentation Finished: Original {len(base_df)} Added {len(counterfactual_df)} ")
380
+ print(f"Total data len: {len(augmented_df)}")
381
+
382
+ return augmented_df
383
+
384
def auto_detect_cf_method(base_df, category_col="category"):
    """Choose the counterfactual strategy from the categories present in ``base_df``.

    Returns ``"actors_actresses"`` when both ``American_actors`` and
    ``American_actresses`` occur in ``category_col``, else ``"cross_category"``.
    """
    present = set(base_df[category_col].unique())
    required = {"American_actors", "American_actresses"}
    return "actors_actresses" if required <= present else "cross_category"
391
+
392
class Tee:
    """File-like fan-out: everything written is forwarded to all wrapped streams."""

    def __init__(self, *streams):
        # Kept as the tuple of targets in the order they were passed.
        self.streams = streams

    def write(self, data):
        """Write ``data`` to each stream in turn, flushing it right after."""
        for sink in self.streams:
            sink.write(data)
            sink.flush()

    def flush(self):
        """Flush every wrapped stream."""
        for sink in self.streams:
            sink.flush()
402
+
403
def generate_counterfactual_augmentations(base_df, text_col="generated", name_col="name", category_col="category", num_cf_per_row=3):
    """Augment ``base_df`` with name-swapped counterfactual rows.

    Collects the unique names of every category, prints a short summary and
    delegates: the actor/actress swap is used when both ``American_actors`` and
    ``American_actresses`` are present, the generic cross-category swap otherwise.

    Returns:
        Whatever the selected generator returns (the augmented DataFrame).
    """
    categories = base_df[category_col].unique().tolist()
    category_names = {
        cat: base_df[base_df[category_col] == cat][name_col].unique().tolist()
        for cat in categories
    }

    print(f"Categories for CF generation: {[f'{cat}({len(names)})' for cat, names in category_names.items()]}")

    has_both_genders = "American_actors" in categories and "American_actresses" in categories
    generator = _generate_actors_actresses_cf if has_both_genders else _generate_cross_category_cf
    return generator(base_df, text_col, name_col, category_col, num_cf_per_row, category_names)
416
+
417
+ def _generate_actors_actresses_cf(base_df, text_col, name_col, category_col, num_cf_per_row, category_names):
418
+ male_names = category_names.get("American_actors", [])
419
+ female_names = category_names.get("American_actresses", [])
420
+
421
+ cf_rows = []
422
+ for idx, row in base_df.iterrows():
423
+ original_text = row[text_col]
424
+ original_name = row[name_col]
425
+ category = row[category_col]
426
+ original_name_clean = original_name.replace("_", " ")
427
+
428
+ if category == "American_actors":
429
+ swap_pool = female_names
430
+ new_category = "American_actresses"
431
+ elif category == "American_actresses":
432
+ swap_pool = male_names
433
+ new_category = "American_actors"
434
+ else:
435
+ continue
436
+
437
+ if len(swap_pool) == 0:
438
+ continue
439
+
440
+ sampled_names = np.random.choice(swap_pool, size=min(num_cf_per_row, len(swap_pool)), replace=False)
441
+
442
+ for new_name in sampled_names:
443
+ new_name_clean = new_name.replace("_", " ")
444
+ new_text = original_text.replace(original_name_clean, new_name_clean, 1)
445
+
446
+ if new_text == original_text:
447
+ continue
448
+
449
+ new_row = row.copy()
450
+ new_row[name_col] = new_name
451
+ new_row[text_col] = new_text
452
+ new_row[category_col] = new_category
453
+ new_row["original_category"] = category
454
+ cf_rows.append(new_row)
455
+
456
+ counterfactual_df = pd.DataFrame(cf_rows)
457
+ augmented_df = pd.concat([base_df, counterfactual_df], ignore_index=True)
458
+
459
+ print(f"\nAugmentation Finished: Original {len(base_df)} Added {len(counterfactual_df)} ")
460
+ print(f"Total data len: {len(augmented_df)}")
461
+ return augmented_df
462
+
463
+ def _generate_cross_category_cf(base_df, text_col, name_col, category_col, num_cf_per_row, category_names):
464
+ categories = list(category_names.keys())
465
+
466
+ cf_rows = []
467
+ for idx, row in base_df.iterrows():
468
+ original_text = row[text_col]
469
+ original_name = row[name_col]
470
+ original_category = row[category_col]
471
+ original_name_clean = original_name.replace("_", " ")
472
+
473
+ other_categories = [cat for cat in categories if cat != original_category]
474
+
475
+ for target_category in other_categories:
476
+ target_names = category_names[target_category]
477
+
478
+ if len(target_names) == 0:
479
+ continue
480
+
481
+ num_to_sample = min(max(1, num_cf_per_row // len(other_categories)), len(target_names))
482
+ sampled_names = np.random.choice(target_names, size=num_to_sample, replace=False)
483
+
484
+ for new_name in sampled_names:
485
+ new_name_clean = new_name.replace("_", " ")
486
+
487
+ new_text = original_text.replace(original_name_clean, new_name_clean, 1)
488
+
489
+ if new_text == original_text:
490
+ original_parts = original_name_clean.split()
491
+ for part in original_parts:
492
+ if len(part) > 2:
493
+ new_text = original_text.replace(part, new_name_clean, 1)
494
+ if new_text != original_text:
495
+ break
496
+
497
+ if new_text == original_text:
498
+ continue
499
+
500
+ new_row = row.copy()
501
+ new_row[name_col] = new_name
502
+ new_row[text_col] = new_text
503
+ new_row[category_col] = target_category
504
+ new_row["original_category"] = original_category
505
+ cf_rows.append(new_row)
506
+
507
+ counterfactual_df = pd.DataFrame(cf_rows)
508
+ augmented_df = pd.concat([base_df, counterfactual_df], ignore_index=True)
509
+
510
+ print(f"\nAugmentation Finished: Original {len(base_df)} Added {len(counterfactual_df)} ")
511
+ print(f"Total data len: {len(augmented_df)}")
512
+
513
+ return augmented_df
514
+
515
+ def _ensure_plot_saved(
516
+ df,
517
+ score_col: str,
518
+ basename: str,
519
+ group_col: str = None,
520
+ target: float = None,
521
+ bins: int = 30,
522
+ ) -> str:
523
+ os.makedirs("data", exist_ok=True)
524
+ path = os.path.join("data", f"{basename}.png")
525
+
526
+ plt.figure(figsize=(8, 5))
527
+ data = df[score_col].dropna().values
528
+
529
+ if group_col and group_col in df.columns:
530
+ for g, sub in df.groupby(group_col):
531
+ vals = sub[score_col].dropna().values
532
+ if len(vals) == 0:
533
+ continue
534
+ plt.hist(vals, bins=bins, alpha=0.4, label=f"{g} (n={len(vals)}, μ={np.mean(vals):.3f})", density=True)
535
+ else:
536
+ plt.hist(data, bins=bins, alpha=0.6, density=True, label=f"All (n={len(data)}, μ={np.mean(data):.3f})")
537
+
538
+ if len(data):
539
+ m = float(np.mean(data))
540
+ plt.axvline(m, linestyle="--", linewidth=2, label=f"mean={m:.3f}")
541
+
542
+ if target is not None:
543
+ plt.axvline(target, linestyle="-.", linewidth=2, label=f"target={target:.3f}")
544
+
545
+ plt.xlabel(score_col)
546
+ plt.ylabel("density")
547
+ plt.title(basename.replace("_", " "))
548
+ plt.legend(loc="best")
549
+ plt.tight_layout()
550
+ plt.savefig(path, dpi=160)
551
+ plt.close()
552
+ return path
frontend/.gitignore ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Logs
2
+ logs
3
+ *.log
4
+ npm-debug.log*
5
+ yarn-debug.log*
6
+ yarn-error.log*
7
+ pnpm-debug.log*
8
+ lerna-debug.log*
9
+
10
+ node_modules
11
+ dist
12
+ dist-ssr
13
+ *.local
14
+
15
+ # Editor directories and files
16
+ .vscode/*
17
+ !.vscode/extensions.json
18
+ .idea
19
+ .DS_Store
20
+ *.suo
21
+ *.ntvs*
22
+ *.njsproj
23
+ *.sln
24
+ *.sw?
frontend/README.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # React + TypeScript + Vite
2
+
3
+ This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
4
+
5
+ Currently, two official plugins are available:
6
+
7
+ - [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) for Fast Refresh
8
+ - [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
9
+
10
+ ## Expanding the ESLint configuration
11
+
12
+ If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules:
13
+
14
+ ```js
15
+ export default tseslint.config([
16
+ globalIgnores(['dist']),
17
+ {
18
+ files: ['**/*.{ts,tsx}'],
19
+ extends: [
20
+ // Other configs...
21
+
22
+ // Remove tseslint.configs.recommended and replace with this
23
+ ...tseslint.configs.recommendedTypeChecked,
24
+ // Alternatively, use this for stricter rules
25
+ ...tseslint.configs.strictTypeChecked,
26
+ // Optionally, add this for stylistic rules
27
+ ...tseslint.configs.stylisticTypeChecked,
28
+
29
+ // Other configs...
30
+ ],
31
+ languageOptions: {
32
+ parserOptions: {
33
+ project: ['./tsconfig.node.json', './tsconfig.app.json'],
34
+ tsconfigRootDir: import.meta.dirname,
35
+ },
36
+ // other options...
37
+ },
38
+ },
39
+ ])
40
+ ```
41
+
42
+ You can also install [eslint-plugin-react-x](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-x) and [eslint-plugin-react-dom](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-dom) for React-specific lint rules:
43
+
44
+ ```js
45
+ // eslint.config.js
46
+ import reactX from 'eslint-plugin-react-x'
47
+ import reactDom from 'eslint-plugin-react-dom'
48
+
49
+ export default tseslint.config([
50
+ globalIgnores(['dist']),
51
+ {
52
+ files: ['**/*.{ts,tsx}'],
53
+ extends: [
54
+ // Other configs...
55
+ // Enable lint rules for React
56
+ reactX.configs['recommended-typescript'],
57
+ // Enable lint rules for React DOM
58
+ reactDom.configs.recommended,
59
+ ],
60
+ languageOptions: {
61
+ parserOptions: {
62
+ project: ['./tsconfig.node.json', './tsconfig.app.json'],
63
+ tsconfigRootDir: import.meta.dirname,
64
+ },
65
+ // other options...
66
+ },
67
+ },
68
+ ])
69
+ ```
frontend/eslint.config.js ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import js from '@eslint/js'
2
+ import globals from 'globals'
3
+ import reactHooks from 'eslint-plugin-react-hooks'
4
+ import reactRefresh from 'eslint-plugin-react-refresh'
5
+ import tseslint from 'typescript-eslint'
6
+ import { globalIgnores } from 'eslint/config'
7
+
8
+ export default tseslint.config([
9
+ globalIgnores(['dist']),
10
+ {
11
+ files: ['**/*.{ts,tsx}'],
12
+ extends: [
13
+ js.configs.recommended,
14
+ tseslint.configs.recommended,
15
+ reactHooks.configs['recommended-latest'],
16
+ reactRefresh.configs.vite,
17
+ ],
18
+ languageOptions: {
19
+ ecmaVersion: 2020,
20
+ globals: globals.browser,
21
+ },
22
+ },
23
+ ])
frontend/index.html ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>Vite + React + TS</title>
8
+ </head>
9
+ <body>
10
+ <div id="root"></div>
11
+ <script type="module" src="/src/main.tsx"></script>
12
+ </body>
13
+ </html>
frontend/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
frontend/package.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "frontend",
3
+ "private": true,
4
+ "version": "0.0.0",
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "vite",
8
+ "build": "vite build",
9
+ "lint": "eslint .",
10
+ "preview": "vite preview"
11
+ },
12
+ "dependencies": {
13
+ "lucide-react": "^0.542.0",
14
+ "react": "^19.1.1",
15
+ "react-dom": "^19.1.1",
16
+ "recharts": "^3.1.2"
17
+ },
18
+ "devDependencies": {
19
+ "@eslint/js": "^9.33.0",
20
+ "@tailwindcss/forms": "^0.5.10",
21
+ "@types/react": "^19.1.10",
22
+ "@types/react-dom": "^19.1.7",
23
+ "@vitejs/plugin-react": "^5.0.0",
24
+ "autoprefixer": "^10.4.21",
25
+ "eslint": "^9.33.0",
26
+ "eslint-plugin-react-hooks": "^5.2.0",
27
+ "eslint-plugin-react-refresh": "^0.4.20",
28
+ "globals": "^16.3.0",
29
+ "postcss": "^8.5.6",
30
+ "tailwindcss": "^3.4.17",
31
+ "typescript": "~5.8.3",
32
+ "typescript-eslint": "^8.39.1",
33
+ "vite": "^7.1.2"
34
+ }
35
+ }
frontend/postcss.config.js ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ export default {
2
+ plugins: {
3
+ tailwindcss: {},
4
+ autoprefixer: {},
5
+ },
6
+ }
frontend/public/vite.svg ADDED
frontend/src/App.css ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #root {
2
+ max-width: 1280px;
3
+ margin: 0 auto;
4
+ padding: 2rem;
5
+ text-align: center;
6
+ }
7
+
8
+ .logo {
9
+ height: 6em;
10
+ padding: 1.5em;
11
+ will-change: filter;
12
+ transition: filter 300ms;
13
+ }
14
+ .logo:hover {
15
+ filter: drop-shadow(0 0 2em #646cffaa);
16
+ }
17
+ .logo.react:hover {
18
+ filter: drop-shadow(0 0 2em #61dafbaa);
19
+ }
20
+
21
+ @keyframes logo-spin {
22
+ from {
23
+ transform: rotate(0deg);
24
+ }
25
+ to {
26
+ transform: rotate(360deg);
27
+ }
28
+ }
29
+
30
+ @media (prefers-reduced-motion: no-preference) {
31
+ a:nth-of-type(2) .logo {
32
+ animation: logo-spin infinite 20s linear;
33
+ }
34
+ }
35
+
36
+ .card {
37
+ padding: 2em;
38
+ }
39
+
40
+ .read-the-docs {
41
+ color: #888;
42
+ }
frontend/src/App.tsx ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, Suspense, lazy } from 'react';
2
+ import { Target } from 'lucide-react';
3
+ import type { JobConfig, Extras } from './types';
4
+ import { JobRunnerProvider, useJobRunner } from './hooks/JobRunnerProvider';
5
+
6
// Pages are loaded lazily; the <Suspense> boundary in AppInner shows a
// placeholder while a page chunk is being fetched.
const ConfigPage = lazy(() => import('./pages/ConfigPage'));
const ResultsPage = lazy(() => import('./pages/ResultsPage'));

/** Tab identifiers. Only 'config' and 'results' are rendered by AppInner. */
type Tab = 'config'|'results'|'reports';

/**
 * Application shell: sticky header, tab bar and the currently selected page.
 * Must be rendered inside a JobRunnerProvider because it calls useJobRunner().
 */
function AppInner() {
  const [tab, setTab] = useState<Tab>('config');
  const { start } = useJobRunner();

  // Kick off a job with the submitted configuration and jump to the results tab.
  // `start` is fire-and-forget here; its promise is not awaited.
  const handleRun = (cfg: JobConfig, extras: Extras) => {
    void start(cfg, extras);
    setTab('results');
  };

  // Class list for a tab button; the highlighted variant is used for the active tab.
  const tabClass = (name: Tab) => {
    const variant =
      tab === name
        ? 'text-white bg-gradient-to-r from-indigo-600 via-violet-600 to-fuchsia-600 shadow-lg shadow-indigo-600/20'
        : 'text-slate-600 hover:text-slate-800 bg-white/70 backdrop-blur border border-white/30 hover:shadow-md';
    return `relative px-4 py-2 rounded-xl text-sm font-medium transition-all ${variant}`;
  };

  const loadingFallback = (
    <div className="p-8 rounded-2xl border border-white/30 bg-white/60 backdrop-blur text-slate-600">
      Loading
    </div>
  );

  return (
    <div className="min-h-screen bg-[radial-gradient(1200px_800px_at_20%_-10%,#c7d2fe_0%,transparent_60%),radial-gradient(1200px_800px_at_120%_20%,#fbcfe8_0%,transparent_55%)] bg-slate-50">
      <header className="sticky top-0 z-50">
        <div className="backdrop-blur-xl bg-white/60 border-b border-white/40 shadow-[0_10px_30px_-12px_rgba(30,41,59,0.25)] supports-[backdrop-filter]:bg-white/40">
          <div className="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 h-16 flex items-center justify-between">
            <div className="flex items-center gap-3">
              <div className="p-2 rounded-xl bg-gradient-to-br from-indigo-600 to-fuchsia-600 shadow-md shadow-indigo-600/30">
                <Target className="w-6 h-6 text-white" />
              </div>
              <div>
                <h1 className="text-lg sm:text-xl font-bold tracking-tight text-slate-900">AAAI Demo</h1>
                <p className="text-xs text-slate-600">Rank Sampling</p>
              </div>
            </div>

            <div className="hidden sm:flex items-center gap-2">
              <span className="px-3 py-1 rounded-full text-xs font-medium bg-emerald-50 text-emerald-700 border border-emerald-200">System Health</span>
              <span className="px-3 py-1 rounded-full text-xs font-medium bg-slate-100 text-slate-700 border border-slate-200">v1.0.0</span>
            </div>
          </div>
        </div>

        <div className="border-b border-white/40 bg-white/30 backdrop-blur supports-[backdrop-filter]:bg-white/20">
          <div className="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 py-3 flex gap-3">
            <button className={tabClass('config')} onClick={() => setTab('config')}>Config Setting</button>
            <button className={tabClass('results')} onClick={() => setTab('results')}>Results</button>
          </div>
        </div>
      </header>

      <main className="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 py-8">
        <Suspense fallback={loadingFallback}>
          {/* onRun receives the extras together with the config */}
          {tab === 'config' && <ConfigPage onRun={handleRun} />}
          {tab === 'results' && <ResultsPage />}
        </Suspense>
      </main>
    </div>
  );
}
71
+
72
/** Application root: wraps the shell in the job-runner context provider. */
export default function App() {
  const shell = <AppInner />;
  return <JobRunnerProvider>{shell}</JobRunnerProvider>;
}
frontend/src/assets/react.svg ADDED
frontend/src/components/MetricCard.tsx ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // src/components/MetricCard.tsx
2
+ import React from 'react';
3
+ import type { ReactNode } from 'react';
4
+ import { TrendingDown, TrendingUp } from 'lucide-react';
5
+
6
/** Props accepted by MetricCard. */
type MetricCardProps = {
  title: string;
  value: string;
  change?: string;
  positive?: boolean;
  icon: ReactNode;
  description?: string;
};

/**
 * Card showing a titled metric value with an icon, an optional change line
 * and an optional description.
 *
 * When `change` is truthy a trend row is shown: `positive` selects the green
 * TrendingDown variant, otherwise the red TrendingUp variant is used.
 */
export default function MetricCard(props: MetricCardProps) {
  const { title, value, change, positive, icon, description } = props;

  const TrendIcon = positive ? TrendingDown : TrendingUp;
  const trendIconClass = positive ? 'w-4 h-4 text-green-500 mr-1' : 'w-4 h-4 text-red-500 mr-1';
  const trendTextClass = `text-sm font-medium ${positive ? 'text-green-600' : 'text-red-600'}`;

  return (
    <div className="bg-white rounded-lg shadow-sm border p-6 hover:shadow-md transition-shadow">
      <div className="flex items-center justify-between">
        <div className="flex-1">
          <p className="text-sm font-medium text-gray-600">{title}</p>
          <p className="text-2xl font-bold text-gray-900 mt-1">{value}</p>
          {change && (
            <div className="flex items-center mt-2">
              <TrendIcon className={trendIconClass} />
              <span className={trendTextClass}>{change}</span>
            </div>
          )}
          {description && <p className="text-xs text-gray-500 mt-1">{description}</p>}
        </div>
        <div className="p-3 bg-blue-50 rounded-full">{icon}</div>
      </div>
    </div>
  );
}
frontend/src/components/PipelineProgress.tsx ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useEffect, useMemo, useRef, useState } from 'react';
2
+ import { CheckCircle2, Loader2, Database, Brain, Sparkles, Rocket, LineChart } from 'lucide-react';
3
+ import { MLBiasAPI } from '../services/api';
4
+ import { useJobRunner } from '../hooks/JobRunnerProvider';
5
+
6
/** Display state of one pipeline step. */
type StepState = 'todo' | 'doing' | 'done';

/**
 * Shape this component expects from MLBiasAPI.checkHealth().
 * Every field is optional; the component treats missing fields as falsy.
 */
type Health = {
  status?: string;
  timestamp?: string;
  loaded_models?: string[];
  dataset_loaded?: boolean;
  generation_results_available?: boolean;
  finetune_running?: boolean;
  // Per-step flags keyed by a backend step id; a value may be a boolean or a step state.
  steps?: Record<string, boolean | StepState>;
};

/** Labels of the steps shown in the progress list (also used as state keys). */
type StepKey =
  | 'Activate Task'
  | 'Load Dataset'
  | 'Load Model'
  | 'Generate and Score'
  | 'Counterfactual'
  | 'Sampling'
  | 'Plot and Output'
  | 'Finetune';
27
+
28
/**
 * Progress panel for a running pipeline job.
 *
 * Step states are derived from three sources: the job/response held by the
 * JobRunner context, a health endpoint polled once per second, and — when the
 * backend reports no step flags at all — the time elapsed since this
 * component mounted. A 'Counterfactual' step that stays in 'doing' longer
 * than STUCK_TIMEOUT seconds is displayed as done.
 */
export default function PipelineProgress() {
  const { result, resp } = useJobRunner();

  const [health, setHealth] = useState<Health | null>(null);
  // True while a health request is outstanding, so a slow backend does not
  // accumulate overlapping requests from the 1 s interval.
  const pollInFlightRef = useRef(false);

  useEffect(() => {
    let cancelled = false;

    const poll = async () => {
      if (pollInFlightRef.current) return;
      pollInFlightRef.current = true;
      try {
        const h = (await MLBiasAPI.checkHealth()) as Health;
        // Skip the update after unmount, and keep the previous object when
        // nothing changed so dependent memos are not invalidated.
        if (!cancelled) {
          setHealth(prev => (JSON.stringify(prev) === JSON.stringify(h) ? prev : h));
        }
      } catch {
        // Best-effort polling: a failed health check is ignored and retried on the next tick.
      } finally {
        pollInFlightRef.current = false;
      }
    };

    void poll();
    const intervalId = window.setInterval(() => void poll(), 1000);
    return () => {
      cancelled = true;
      window.clearInterval(intervalId);
    };
  }, []);

  // Seconds since this component mounted.
  const [elapsed, setElapsed] = useState<number>(0);
  useEffect(() => {
    const startedAt = Date.now();
    const intervalId = window.setInterval(() => {
      setElapsed(Math.floor((Date.now() - startedAt) / 1000));
    }, 1000);
    return () => window.clearInterval(intervalId);
  }, []);

  const modelName = result?.config?.languageModel || '';

  const wantFT = Boolean(
    result?.config?.enableFineTuning ?? (resp?.results as any)?.config_used?.enableFineTuning
  );

  // Step flags reported by the backend; health flags override response flags.
  const backendSteps = useMemo(() => {
    const fromResp = ((resp?.results as any)?.steps || {}) as Record<string, boolean>;
    const fromHealth = ((health?.steps || {}) as Record<string, boolean | 'todo' | 'doing' | 'done'>);
    const merged: Record<string, boolean> = { ...fromResp };
    Object.keys(fromHealth).forEach(k => {
      const v = (fromHealth as any)[k];
      merged[k] = v === true || v === 'doing' || v === 'done';
    });
    return merged;
  }, [health?.steps, resp?.results]);

  // Memoised so the `{}` fallback keeps a stable identity between renders.
  const resultsAny = useMemo(() => (resp?.results ?? {}) as any, [resp?.results]);

  const inferred = useMemo(() => {
    const hasData = Boolean(health?.dataset_loaded);
    const hasModel = Boolean(health?.loaded_models && health.loaded_models.length > 0);

    const genDone = Boolean(
      backendSteps['3_generate_and_eval'] ||
        health?.generation_results_available ||
        resultsAny.generation_done
    );

    const r4Flag = Boolean(
      backendSteps['4_rank_sampling_original'] ||
        resultsAny.rank_sampling_original_done ||
        resultsAny.rank_sampling?.original_done
    );

    const r5Flag = Boolean(
      backendSteps['5_rank_sampling_cf'] ||
        resultsAny.rank_sampling_cf_done ||
        resultsAny.rank_sampling?.cf_done
    );

    const plotsFlag = Boolean(
      backendSteps['6_plots_and_metrics'] ||
        resultsAny.plot_urls ||
        resultsAny.plots_ready ||
        (resultsAny.plots &&
          (resultsAny.plots.original_sentiment || resultsAny.plots.counterfactual_sentiment))
    );

    const ftDoneFlag = Boolean(
      backendSteps['7_finetune'] === true ||
        resultsAny.finetune_done ||
        resultsAny.finetune?.completed ||
        resultsAny.finetune?.saved_model_path
    );

    const ftRunning = Boolean(resultsAny.finetune?.running || health?.finetune_running);

    // With no step signal from the backend at all, later steps are advanced by elapsed time.
    const noStepSignals =
      Object.keys(backendSteps || {}).length === 0 &&
      !resultsAny.rank_sampling_original_done &&
      !resultsAny.rank_sampling_cf_done &&
      !resultsAny.plots_ready &&
      !resultsAny.finetune_done;

    const cfByTime = noStepSignals && genDone && elapsed > 30;
    const rsByTime = noStepSignals && genDone && elapsed > 45;
    const plotsByTime = noStepSignals && genDone && elapsed > 70;

    // Any later step having finished implies the counterfactual step finished too.
    const cfDone = Boolean(
      backendSteps['3_5_counterfactual'] ||
        resultsAny.counterfactual_done ||
        resultsAny.counterfactual_results ||
        r4Flag || r5Flag || plotsFlag || ftDoneFlag ||
        cfByTime
    );

    const r4 = r4Flag || rsByTime;
    const r5 = r5Flag || rsByTime;
    const plots = plotsFlag || plotsByTime;
    const ftDone = ftDoneFlag;

    return { hasData, hasModel, genDone, cfDone, r4, r5, plots, ftDone, ftRunning };
  }, [backendSteps, health, resultsAny, elapsed]);

  const rawSteps = useMemo<Record<StepKey, StepState>>(() => {
    const states: Record<StepKey, StepState> = {
      'Activate Task': 'todo',
      'Load Dataset': 'todo',
      'Load Model': 'todo',
      'Generate and Score': 'todo',
      'Counterfactual': 'todo',
      'Sampling': 'todo',
      'Plot and Output': 'todo',
      'Finetune': 'todo',
    };

    if (result?.status === 'running') {
      states['Activate Task'] = 'doing';
    }
    if (inferred.hasData) {
      states['Activate Task'] = 'done';
      states['Load Dataset'] = 'done';
    }

    if (inferred.hasModel) {
      states['Load Model'] = 'done';
    } else if (inferred.hasData) {
      states['Load Model'] = 'doing';
    }

    if (inferred.genDone) {
      states['Generate and Score'] = 'done';
    } else if (inferred.hasModel) {
      states['Generate and Score'] = 'doing';
    }

    if (inferred.cfDone) {
      states['Counterfactual'] = 'done';
    } else if (states['Generate and Score'] === 'done') {
      states['Counterfactual'] = 'doing';
    }

    const shouldStartSampling =
      inferred.r4 || inferred.r5 ||
      states['Counterfactual'] === 'done' ||
      (states['Generate and Score'] === 'done' && elapsed > 20);

    if (inferred.r4 && inferred.r5) {
      states['Sampling'] = 'done';
    } else if (shouldStartSampling) {
      states['Sampling'] = 'doing';
    }

    const shouldStartPlotting =
      inferred.plots ||
      states['Sampling'] === 'done' ||
      (states['Sampling'] === 'doing' && elapsed > 40);

    if (inferred.plots) {
      states['Plot and Output'] = 'done';
    } else if (shouldStartPlotting) {
      states['Plot and Output'] = 'doing';
    }

    if (wantFT) {
      if (inferred.ftDone) states['Finetune'] = 'done';
      else if (inferred.ftRunning || states['Plot and Output'] === 'done')
        states['Finetune'] = 'doing';
      else states['Finetune'] = 'todo';
    } else {
      states['Finetune'] = 'todo';
    }

    return states;
  }, [elapsed, inferred, wantFT, result?.status]);

  const STUCK_TIMEOUT = 30; // seconds
  // Timestamp (ms) at which each step entered 'doing'; absent while it is not 'doing'.
  const [enteredAt, setEnteredAt] = useState<Partial<Record<StepKey, number>>>({});
  const [forcedDone, setForcedDone] = useState<Partial<Record<StepKey, boolean>>>({});

  useEffect(() => {
    const now = Date.now();
    // Functional update: always works from the latest timestamps rather than
    // whatever this effect's closure captured.
    setEnteredAt(prev => {
      const next = { ...prev };
      let changed = false;
      (Object.keys(rawSteps) as StepKey[]).forEach((k) => {
        if (rawSteps[k] === 'doing' && !next[k]) {
          next[k] = now;
          changed = true;
        } else if (rawSteps[k] !== 'doing' && next[k]) {
          delete next[k];
          changed = true;
        }
      });
      return changed ? next : prev;
    });
  }, [rawSteps]);

  useEffect(() => {
    const k: StepKey = 'Counterfactual';
    const startedAt = enteredAt[k];
    if (rawSteps[k] === 'doing' && startedAt && Date.now() - startedAt > STUCK_TIMEOUT * 1000) {
      if (!forcedDone[k]) setForcedDone(prev => ({ ...prev, [k]: true }));
    }
  }, [enteredAt, rawSteps, forcedDone]);

  // Displayed states: the derived states with forced completions applied on top.
  const steps = useMemo(() => {
    const s = { ...rawSteps } as Record<StepKey, StepState>;
    (Object.keys(forcedDone) as StepKey[]).forEach((k) => {
      if (forcedDone[k]) s[k] = 'done';
    });
    return s;
  }, [rawSteps, forcedDone]);

  const ft = resultsAny?.finetune || {};
  const downloadPath: string | undefined =
    ft.download_url || ft.model_url || ft.saved_model_path || resultsAny?.finetune_model_url;

  const downloadHref = downloadPath ? MLBiasAPI.resolvePath(downloadPath) : undefined;

  const baseSteps = [
    { key: 'Activate Task', icon: Rocket },
    { key: 'Load Dataset', icon: Database },
    { key: 'Load Model', icon: Brain },
    { key: 'Generate and Score', icon: Sparkles },
    { key: 'Counterfactual', icon: Sparkles },
    { key: 'Sampling', icon: LineChart },
    { key: 'Plot and Output', icon: LineChart },
  ] as const;

  // The Finetune step is listed only when fine-tuning was requested.
  const stepList = wantFT
    ? [...baseSteps, { key: 'Finetune', icon: Rocket } as const]
    : baseSteps;

  const completedCount = stepList.reduce(
    (acc, s) => acc + (steps[s.key as StepKey] === 'done' ? 1 : 0),
    0
  );
  const doingCount = stepList.reduce(
    (acc, s) => acc + (steps[s.key as StepKey] === 'doing' ? 1 : 0),
    0
  );
  // A step in progress counts as half a step towards the bar.
  const percent = Math.min(
    100,
    Math.round(((completedCount + doingCount * 0.5) / stepList.length) * 100)
  );

  const hasStuckStep =
    Object.values(steps).some((state) => state === 'doing') &&
    elapsed > 60 &&
    completedCount < stepList.length - 1;

  return (
    <div className="relative overflow-hidden rounded-2xl border border-indigo-200/50 bg-white/70 backdrop-blur">
      <div className="absolute inset-0 -z-10 bg-[radial-gradient(800px_300px_at_20%_-20%,rgba(99,102,241,0.15),transparent_60%),radial-gradient(800px_300px_at_120%_0%,rgba(244,114,182,0.15),transparent_60%)]" />
      <div className="p-6">
        <div className="flex items-center justify-between mb-4">
          <div>
            <h3 className="text-lg font-semibold text-slate-900">Pipeline Running</h3>
            <p className="text-sm text-slate-600">
              Model: <span className="font-medium text-slate-800">{modelName || '(未指定)'}</span>
            </p>
            {hasStuckStep && (
              <p className="text-xs text-amber-600 mt-1">⚠️ Some steps may run slowly and are automatically attempted to proceed.</p>
            )}
          </div>
          <div className="flex items-center gap-2 text-slate-600">
            <Loader2 className="w-5 h-5 animate-spin" />
            <span className="text-sm">Executed {elapsed}s</span>
          </div>
        </div>

        <div className="w-full h-3 rounded-full bg-slate-200 overflow-hidden">
          <div
            className="h-3 bg-gradient-to-r from-indigo-500 via-violet-500 to-fuchsia-500 transition-all duration-500"
            style={{ width: `${percent}%` }}
          />
        </div>

        <ol className="mt-6 grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-3">
          {stepList.map(({ key, icon: Icon }) => {
            const state = steps[key as StepKey];
            const isDone = state === 'done';
            const isDoing = state === 'doing';
            const startTs = enteredAt[key as StepKey];
            const isStuck =
              isDoing && Boolean(startTs) && (Date.now() - (startTs as number)) / 1000 > STUCK_TIMEOUT;

            return (
              <li
                key={key}
                className={`flex items-center gap-3 rounded-xl border p-3 transition-all duration-300 ${
                  isDone
                    ? 'border-emerald-200 bg-emerald-50'
                    : isDoing
                      ? isStuck
                        ? 'border-amber-300 bg-amber-100'
                        : 'border-amber-200 bg-amber-50'
                      : 'border-slate-200 bg-white/70'
                }`}
              >
                <div
                  className={`rounded-lg p-2 transition-colors ${
                    isDone
                      ? 'bg-emerald-100 text-emerald-700'
                      : isDoing
                        ? isStuck
                          ? 'bg-amber-200 text-amber-800'
                          : 'bg-amber-100 text-amber-700'
                        : 'bg-slate-100 text-slate-600'
                  }`}
                >
                  {isDone ? (
                    <CheckCircle2 className="w-5 h-5" />
                  ) : isDoing ? (
                    <Loader2 className="w-5 h-5 animate-spin" />
                  ) : (
                    <Icon className="w-5 h-5" />
                  )}
                </div>
                <div className="flex-1">
                  <div className="text-sm font-medium text-slate-900">{key}</div>
                  <div className="text-xs text-slate-600">
                    {isDone ? 'Finished' : isDoing ? (isStuck ? 'Running...' : 'Running…') : 'Waiting'}
                  </div>
                </div>
              </li>
            );
          })}
        </ol>

        {/* Fine-tuning finished → show the model download link */}
        {wantFT && inferred.ftDone && downloadHref && (
          <div className="mt-6">
            <a
              href={downloadHref}
              className="inline-flex items-center rounded-xl border bg-white px-4 py-2 text-sm font-medium text-slate-800 hover:bg-slate-50"
              target="_blank"
              rel="noopener noreferrer"
            >
              <Rocket className="w-4 h-4 mr-2" />
              Download Finetuned Model
            </a>
            {ft?.saved_model_path && (
              <p className="mt-2 text-xs text-slate-500 break-all">
                Model path: {String(ft.saved_model_path)}
              </p>
            )}
          </div>
        )}
      </div>
    </div>
  );
}
frontend/src/components/validators/DatasetValidator.tsx ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // src/components/validators/DatasetValidator.tsx
2
+ import { useEffect } from 'react';
3
+ import { Loader, CheckCircle, XCircle, ExternalLink } from 'lucide-react';
4
+ import { useHFDatasetValidator } from '../../hooks/useHFValidators';
5
+
6
/**
 * Inline validation status for a Hugging Face dataset id.
 *
 * Re-runs the (debounced) validator whenever `datasetId` changes and renders
 * a spinner, a success summary or the error message. Renders nothing for ids
 * that do not contain a '/'.
 */
export default function DatasetValidator({ datasetId }: { datasetId: string }) {
  const { loading, result, validate } = useHFDatasetValidator();

  useEffect(() => {
    validate(datasetId);
  }, [datasetId, validate]);

  if (!datasetId?.includes('/')) return null;

  // Description preview: the ellipsis is appended only when the text was
  // actually cut, so short descriptions are shown verbatim.
  const PREVIEW_CHARS = 100;
  const fullDescription: string = result?.isValid ? String(result.datasetInfo?.description ?? '') : '';
  const descriptionPreview =
    fullDescription.length > PREVIEW_CHARS
      ? `${fullDescription.slice(0, PREVIEW_CHARS)}...`
      : fullDescription;

  return (
    <div className="mt-3">
      {loading && (
        <div className="flex items-center gap-2 p-3 bg-yellow-50 rounded-lg">
          <Loader className="w-4 h-4 text-yellow-600 animate-spin" />
          <span className="text-sm text-yellow-800">Validating Dataset...</span>
        </div>
      )}

      {!!result && !loading && (
        <div className={`p-3 rounded-lg ${result.isValid ? 'bg-green-50' : 'bg-red-50'}`}>
          <div className="flex items-start gap-2">
            {result.isValid ? (
              <CheckCircle className="w-4 h-4 text-green-600 mt-0.5" />
            ) : (
              <XCircle className="w-4 h-4 text-red-600 mt-0.5" />
            )}
            <div className="flex-1">
              {result.isValid ? (
                <>
                  <p className="text-sm font-medium text-green-800">✅ Dataset Verification Successful</p>
                  <div className="mt-2 space-y-1 text-xs text-green-700">
                    <p>
                      <strong>Author: </strong>
                      {result.datasetInfo.author}
                    </p>
                    <p>
                      <strong>Download: </strong>
                      {result.datasetInfo.downloads.toLocaleString()}
                    </p>
                    {!!result.datasetInfo.task_categories?.length && (
                      <p>
                        <strong>Task: </strong>
                        {result.datasetInfo.task_categories.join(', ')}
                      </p>
                    )}
                    <p>
                      <strong>Description: </strong>
                      {descriptionPreview}
                    </p>
                  </div>
                  <a
                    href={`https://huggingface.co/datasets/${datasetId}`}
                    target="_blank"
                    rel="noreferrer"
                    className="inline-flex items-center gap-1 text-xs text-blue-600 hover:text-blue-800 mt-2"
                  >
                    <ExternalLink className="w-3 h-3" />
                    <span>View on Hugging Face</span>
                  </a>
                </>
              ) : (
                <>
                  <p className="text-sm font-medium text-red-800">❌ Dataset Verification Failed</p>
                  <p className="text-xs text-red-700 mt-1">{result.error}</p>
                </>
              )}
            </div>
          </div>
        </div>
      )}
    </div>
  );
}
frontend/src/components/validators/ModelValidator.tsx ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // src/components/validators/ModelValidator.tsx
2
+ import { useEffect } from 'react';
3
+ import { Loader, CheckCircle, XCircle, ExternalLink } from 'lucide-react';
4
+ import { useHFModelValidator } from '../../hooks/useHFValidators';
5
+
6
/**
 * Inline validation status for a Hugging Face model id.
 *
 * Re-runs the (debounced) validator whenever `modelId` or `type` changes and
 * renders a spinner, a success summary or the error message. Renders nothing
 * for ids that do not contain a '/'.
 *
 * @param modelId Model id as typed by the user.
 * @param type    Role the model is expected to fill; forwarded to the validator.
 */
export default function ModelValidator({
  modelId,
  type,
}: {
  modelId: string;
  type: 'language' | 'scorer';
}) {
  const { loading, result, validate } = useHFModelValidator();

  useEffect(() => {
    validate(modelId, type);
  }, [modelId, type, validate]);

  if (!modelId?.includes('/')) return null;

  return (
    <div className="mt-3">
      {loading && (
        <div className="flex items-center gap-2 p-3 bg-yellow-50 rounded-lg">
          <Loader className="w-4 h-4 text-yellow-600 animate-spin" />
          <span className="text-sm text-yellow-800">Validating Model...</span>
        </div>
      )}

      {!!result && !loading && (
        <div className={`p-3 rounded-lg ${result.isValid ? 'bg-green-50' : 'bg-red-50'}`}>
          <div className="flex items-start gap-2">
            {result.isValid ? (
              <CheckCircle className="w-4 h-4 text-green-600 mt-0.5" />
            ) : (
              <XCircle className="w-4 h-4 text-red-600 mt-0.5" />
            )}
            <div className="flex-1">
              {result.isValid ? (
                <>
                  {/* Spelling fixed: was "Sucessful". */}
                  <p className="text-sm font-medium text-green-800">✅ Model Verification Successful</p>
                  <div className="mt-2 space-y-1 text-xs text-green-700">
                    <p>
                      <strong>Author: </strong>
                      {result.modelInfo.author}
                    </p>
                    <p>
                      <strong>Download: </strong>
                      {result.modelInfo.downloads.toLocaleString()}
                    </p>
                    {!!result.modelInfo.pipeline_tag && (
                      <p>
                        <strong>Task: </strong>
                        {result.modelInfo.pipeline_tag}
                      </p>
                    )}
                  </div>
                  <a
                    href={`https://huggingface.co/${modelId}`}
                    target="_blank"
                    rel="noreferrer"
                    className="inline-flex items-center gap-1 text-xs text-blue-600 hover:text-blue-800 mt-2"
                  >
                    <ExternalLink className="w-3 h-3" />
                    <span>View on Hugging Face</span>
                  </a>
                </>
              ) : (
                <>
                  <p className="text-sm font-medium text-red-800">❌ Model Verification Failed</p>
                  <p className="text-xs text-red-700 mt-1">{result.error}</p>
                </>
              )}
            </div>
          </div>
        </div>
      )}
    </div>
  );
}
frontend/src/constants/datasets.ts ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import type { Dataset } from '../types';
2
+
3
// The single built-in dataset entry (Hugging Face id + display name).
const BOLD_DATASET: Dataset = { id: 'AmazonScience/bold', name: 'BOLD' };

/** Datasets offered by the UI; exported both by name and as the default export. */
export const DATASETS: Dataset[] = [BOLD_DATASET];

export default DATASETS;
frontend/src/constants/models.ts ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import type { Dataset, Model } from '../types';
2
+
3
/** Dataset list exported from this module (same single entry as constants/datasets.ts). */
export const DATASETS: Dataset[] = [
  {
    id: 'AmazonScience/bold',
    name: 'BOLD',
  },
];

/** Language models offered for selection; `id` is the Hugging Face repository id. */
export const LM_MODELS: Model[] = [
  {
    id: 'microsoft/DialoGPT-large',
    name: 'DialoGPT-large',
    type: 'language',
    description: 'Microsoft 對話生成模型',
    provider: 'Microsoft',
  },
  {
    id: 'openai-community/gpt2',
    name: 'GPT-2',
    type: 'language',
    description: 'OpenAI GPT-2 基礎模型',
    provider: 'OpenAI',
  },
  {
    id: 'EleutherAI/gpt-neo-2.7B',
    name: 'GPT-Neo-2.7B',
    type: 'language',
    description: 'EleutherAI 開源語言模型',
    provider: 'EleutherAI',
  },
];
frontend/src/hooks/JobRunnerProvider.tsx ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { createContext, useContext, useMemo, useState } from 'react';
2
+ import type { JobConfig, JobResult } from '../types';
3
+ import { MLBiasAPI } from '../services/api';
4
+
5
/** Plot references returned by the pipeline (presumably paths or URLs — confirm against the API). */
type PipelinePlots = {
  original_sentiment: string;
  counterfactual_sentiment: string;
};

/** `results` payload of a pipeline response as consumed by this provider. */
type PipelineResultsDTO = {
  generation_file: string;
  sentiment_subset_file: string;
  cf_sentiment_subset_file: string;
  metrics: {
    finalMeanDiff: number;
    cfFinalMeanDiff: number;
    // Optional: the provider substitutes 0 and 100 respectively when these are absent.
    reductionPct?: number;
    stableCoverage?: number;
  };
  plots: PipelinePlots;
  finetuned_model_zip?: string;
  finetuned_model_dir?: string;
  run_config_files?: {
    json?: string;
    markdown?: string;
  };
};

/** Envelope returned by MLBiasAPI.runPipeline. */
type PipelineResponseDTO = {
  status: 'success' | 'error';
  message: string;
  timestamp: string;
  results: PipelineResultsDTO;
};

/** Additional run options merged into the config before it is sent. */
type Extras = {
  datasetLimit: number;
};

/** Value exposed through the job-runner context. */
type Ctx = {
  // Current job record; null until `start` has been called.
  result: JobResult | null;
  // Raw pipeline response of the last successful run.
  resp?: PipelineResponseDTO;
  // True while a run is in flight.
  loading: boolean;
  // Message of the last failure, if any.
  error?: string;
  // Starts a pipeline run.
  start: (cfg: JobConfig, extras: Extras) => Promise<void>;
  // Reference to MLBiasAPI.resolvePath.
  url: (p?: string) => string;
};

// Default is undefined so useJobRunner can detect a missing provider.
const JobRunnerContext = createContext<Ctx | undefined>(undefined);
50
+
51
/**
 * Provides pipeline-run state (`result`, `resp`, `loading`, `error`) and the
 * `start` action to descendants via JobRunnerContext.
 *
 * `start` never rejects: failures are reported through `error` and a job
 * record with status 'failed'.
 */
export function JobRunnerProvider({ children }: { children: React.ReactNode }) {
  const [result, setResult] = useState<JobResult | null>(null);
  const [resp, setResp] = useState<PipelineResponseDTO | undefined>();
  const [loading, setLoading] = useState(false);
  const [error, setErr] = useState<string | undefined>();

  // Sequence number of the most recent `start` call. A run that finishes
  // after a newer run has begun must not overwrite the newer run's state.
  const latestRunRef = React.useRef(0);

  const start = React.useCallback<Ctx['start']>(async (cfg, extras) => {
    const runId = ++latestRunRef.current;
    const isLatest = () => latestRunRef.current === runId;

    setLoading(true);
    setErr(undefined);
    setResp(undefined);

    const now = new Date().toISOString();
    setResult({
      id: crypto.randomUUID(),
      status: 'running',
      progress: 0,
      config: cfg,
      createdAt: now,
      updatedAt: now,
    });

    try {
      // The backend receives the config with the dataset limit merged in.
      const cfgToSend = {
        ...cfg,
        datasetLimit: extras.datasetLimit,
      };

      // eslint-disable-next-line @typescript-eslint/no-explicit-any -- runPipeline's parameter type is not visible here
      const r = await MLBiasAPI.runPipeline(cfgToSend as any);
      if (!isLatest()) return;

      // A response can carry status 'error'; treat it as a failed run
      // instead of marking the job completed.
      if (r.status === 'error') {
        throw new Error(r.message || 'Pipeline returned an error');
      }

      setResp(r);

      const done = new Date().toISOString();
      setResult((prev) => ({
        ...(prev as JobResult),
        status: 'completed',
        progress: 100,
        updatedAt: done,
        completedAt: done,
        metrics: {
          finalMeanDiff: r.results.metrics.finalMeanDiff,
          reductionPct: r.results.metrics.reductionPct ?? 0,
          stableCoverage: r.results.metrics.stableCoverage ?? 100,
        },
      }));
    } catch (e: unknown) {
      if (!isLatest()) return;
      const message = (e as { message?: string } | null | undefined)?.message || String(e);
      setErr(message);
      setResult((prev) =>
        prev
          ? { ...prev, status: 'failed', progress: 100, updatedAt: new Date().toISOString() }
          : prev
      );
    } finally {
      // Only the latest run may clear the loading flag.
      if (isLatest()) setLoading(false);
    }
  }, []);

  const url = MLBiasAPI.resolvePath;

  const value = useMemo<Ctx>(
    () => ({ result, resp, loading, error, start, url }),
    [result, resp, loading, error, start, url]
  );

  return <JobRunnerContext.Provider value={value}>{children}</JobRunnerContext.Provider>;
}
116
+
117
/**
 * Returns the job-runner context value.
 *
 * @throws Error when called outside a JobRunnerProvider.
 */
export function useJobRunner() {
  const value = useContext(JobRunnerContext);
  if (value) {
    return value;
  }
  throw new Error('useJobRunner must be used within JobRunnerProvider');
}
frontend/src/hooks/useHFValidators.ts ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // src/hooks/useHFValidators.ts
2
+ import { useMemo, useRef, useState } from 'react';
3
+ import { fetchHFModel, fetchHFDataset } from '../services/hf';
4
+
5
/**
 * Returns a debounced wrapper around `fn`: every call restarts a timer of
 * `ms` milliseconds, and `fn` runs once with the arguments of the last call
 * when the timer expires.
 *
 * @param fn Function to debounce; its return value is ignored.
 * @param ms Quiet period in milliseconds (default 350).
 */
const debounce = <A extends unknown[]>(fn: (...args: A) => void, ms = 350) => {
  let timer: ReturnType<typeof setTimeout> | undefined;
  return (...args: A): void => {
    if (timer !== undefined) clearTimeout(timer);
    timer = setTimeout(() => fn(...args), ms);
  };
};
12
+
13
/**
 * Hook that validates a Hugging Face model id against an expected role.
 *
 * `validate(modelId, expected)` is debounced. It fetches the model metadata,
 * derives its task from `pipeline_tag` (falling back to the id and tags) and
 * sets `result` to `{ isValid: true, modelInfo }` or `{ isValid: false, error }`.
 * A 'language' model must be text-generation or text2text-generation; a
 * 'scorer' must be text-classification. Ids without a '/' reset `result` to null.
 *
 * `result` stays loosely typed because consumers read its fields without narrowing.
 */
export function useHFModelValidator() {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- payload shape is consumed untyped by callers
  const cache = useRef<Map<string, any>>(new Map());
  // Sequence number of the most recent validation; responses belonging to an
  // older request are cached but not shown.
  const latestRequest = useRef(0);
  const [loading, setLoading] = useState(false);
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- see note above
  const [result, setResult] = useState<any>(null);

  const validate = useMemo(
    () =>
      debounce(async (modelId: string, expected: 'language' | 'scorer') => {
        const requestId = ++latestRequest.current;
        const isLatest = () => latestRequest.current === requestId;

        if (!modelId?.includes('/')) {
          setLoading(false);
          setResult(null);
          return;
        }

        // The verdict depends on the expected role, so the role is part of the key.
        const cacheKey = `${expected}:${modelId}`;
        if (cache.current.has(cacheKey)) {
          setLoading(false);
          setResult(cache.current.get(cacheKey));
          return;
        }

        setLoading(true);
        try {
          const info = await fetchHFModel(modelId);
          const id: string = info.id?.toLowerCase() ?? '';
          const tags: string[] = info.tags ?? [];
          let actual: string | undefined = info.pipeline_tag;

          // No explicit pipeline tag: infer the task from the id and tags.
          if (!actual) {
            if (id.includes('t5') || tags.includes('text2text-generation')) actual = 'text2text-generation';
            else if (tags.includes('text-generation')) actual = 'text-generation';
            else if (tags.includes('text-classification')) actual = 'text-classification';
          }

          const ok =
            expected === 'language'
              ? ['text-generation', 'text2text-generation'].includes(actual || '')
              : actual === 'text-classification';

          const payload = ok
            ? {
                isValid: true,
                modelInfo: {
                  id: info.id,
                  downloads: info.downloads ?? 0,
                  pipeline_tag: actual,
                  tags,
                  author: info.id?.split('/')?.[0] ?? 'unknown',
                  modelName: info.id?.split('/')?.[1] ?? info.id,
                },
              }
            : {
                isValid: false,
                error:
                  expected === 'language'
                    ? `Model task should be text-generation or text2text-generation, but is ${actual || 'Unknown'}`
                    : `Model task should be text-classification, but is ${actual || 'Unknown'}`,
              };

          cache.current.set(cacheKey, payload);
          if (isLatest()) setResult(payload);
        } catch (e: unknown) {
          if (isLatest()) {
            const message =
              (e as { message?: string } | null | undefined)?.message || 'Error when validating model';
            setResult({ isValid: false, error: message });
          }
        } finally {
          if (isLatest()) setLoading(false);
        }
      }),
    []
  );

  return { loading, result, validate };
}
80
+
81
/**
 * Hook that validates a Hugging Face dataset id.
 *
 * `validate(datasetId)` is debounced. Ids without a '/' reset `result` to
 * null; ids that fail the `owner/name` character check produce an error
 * result without a network call; otherwise the dataset metadata is fetched
 * and `result` becomes `{ isValid: true, datasetInfo }` or
 * `{ isValid: false, error }`. Successful lookups are cached per id.
 *
 * `result` stays loosely typed because consumers read its fields without narrowing.
 */
export function useHFDatasetValidator() {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- payload shape is consumed untyped by callers
  const cache = useRef<Map<string, any>>(new Map());
  // Sequence number of the most recent validation; responses belonging to an
  // older request are cached but not shown.
  const latestRequest = useRef(0);
  const [loading, setLoading] = useState(false);
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- see note above
  const [result, setResult] = useState<any>(null);

  const validate = useMemo(
    () =>
      debounce(async (datasetId: string) => {
        const requestId = ++latestRequest.current;
        const isLatest = () => latestRequest.current === requestId;

        if (!datasetId?.includes('/')) {
          setLoading(false);
          setResult(null);
          return;
        }
        if (cache.current.has(datasetId)) {
          setLoading(false);
          setResult(cache.current.get(datasetId));
          return;
        }
        const valid = /^[a-zA-Z0-9._-]+\/[a-zA-Z0-9._-]+$/.test(datasetId);
        if (!valid) {
          setLoading(false);
          setResult({ isValid: false, error: 'Incorrect Dataset ID Format' });
          return;
        }

        setLoading(true);
        try {
          const info = await fetchHFDataset(datasetId);
          const payload = {
            isValid: true,
            datasetInfo: {
              id: info.id,
              author: info.id?.split('/')?.[0] ?? 'unknown',
              datasetName: info.id?.split('/')?.[1] ?? info.id,
              downloads: info.downloads ?? 0,
              tags: info.tags ?? [],
              description: info.description ?? 'No Description',
              task_categories: info.task_categories ?? [],
            },
          };
          cache.current.set(datasetId, payload);
          if (isLatest()) setResult(payload);
        } catch (e: unknown) {
          if (isLatest()) {
            const message =
              (e as { message?: string } | null | undefined)?.message ||
              'An error occurred while validating the dataset';
            setResult({ isValid: false, error: message });
          }
        } finally {
          if (isLatest()) setLoading(false);
        }
      }),
    []
  );

  return { loading, result, validate };
}
frontend/src/hooks/useIterationData.ts ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useMemo } from 'react';
2
+
3
/**
 * Builds a deterministic series of synthetic per-iteration values.
 *
 * The series is produced by a seeded `mulberry32` generator, so the same
 * `seed` and `points` always yield the same numbers. The result is memoized
 * on both arguments.
 *
 * @param seed   Initial state of the pseudo-random generator.
 * @param points Number of entries to generate.
 * @returns One entry per iteration (1-based `iteration`) with
 *          `meanDifference` (never below 0.1), `groupA` and `groupB`.
 */
export function useIterationData(seed = 42, points = 50) {
  // Seeded generator returning floats in [0, 1).
  const mulberry32 = (state: number) => (): number => {
    state += 0x6D2B79F5;
    let mixed = state;
    mixed = Math.imul(mixed ^ (mixed >>> 15), mixed | 1);
    mixed ^= mixed + Math.imul(mixed ^ (mixed >>> 7), mixed | 61);
    return ((mixed ^ (mixed >>> 14)) >>> 0) / 4294967296;
  };

  return useMemo(() => {
    const nextRandom = mulberry32(seed);
    return Array.from({ length: points }, (_unused, index) => {
      // The three draws must stay in this order to reproduce the same series.
      const meanDifference = Math.max(0.1, 0.8 - index * 0.012 + nextRandom() * 0.1);
      const groupA = 0.7 - index * 0.006 + nextRandom() * 0.05;
      const groupB = 0.3 + index * 0.003 + nextRandom() * 0.05;
      return { iteration: index + 1, meanDifference, groupA, groupB };
    });
  }, [seed, points]);
}
frontend/src/hooks/useJobRunner.ts ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useEffect, useMemo, useRef, useState } from 'react';
2
+ import type { JobConfig, JobResult } from '../types';
3
+ import type { PipelineResponseDTO } from '../services/api';
4
+ import { MLBiasAPI } from '../services/api';
5
+
6
/** State of a single pipeline step as it may appear in a health payload. */
type HealthStepState = boolean | 'todo' | 'doing' | 'done';

/**
 * Loose shape of the payload returned by the health endpoint.
 * Every field is optional; consumers must tolerate any of them being absent.
 */
type HealthLike = {
  // Identification and timestamps.
  job_id?: string;
  timestamp?: string;
  updated_at?: string;

  // Resource / availability flags.
  dataset_loaded?: boolean;
  loaded_models?: string[];
  generation_results_available?: boolean;
  finetune_running?: boolean;

  // Progress reporting: per-step state keyed by step name, plus overall status.
  steps?: Record<string, HealthStepState>;
  completed?: boolean;
  status?: string;
};
18
+
19
/** Live progress information derived from health polling. */
type JobRunnerLiveState = {
  /** Most recent health payload, or `null` before the first poll. */
  health: HealthLike | null;
  /** Per-step flags, keyed by step name. */
  steps: Record<string, boolean>;
  /** Timestamp reported by the latest health payload, if any. */
  updatedAt: string | null;
  /** Whether a fine-tuning run is reported as in progress. */
  finetuneRunning: boolean;
  /** Overall progress as a percentage. */
  progressPercent: number;
};

/** Everything exposed by the `useJobRunner` hook. */
type UseJobRunnerReturn = {
  /** Client-side record of the current job, or `null` when none was started. */
  result: JobResult | null;
  /** Raw pipeline response once it is available. */
  resp: PipelineResponseDTO | undefined;
  /** True while a job is being started or polled. */
  loading: boolean;
  /** Last error message, if any. */
  error?: string;

  /** Starts a pipeline run with the given configuration. */
  start: (config: JobConfig) => Promise<void>;
  /** Stops polling and clears the loading flag. */
  cancel: () => void;

  /** Identifier of the current job, or `null` when none was started. */
  jobId: string | null;
  live: JobRunnerLiveState;

  /** Path resolver re-exported from the API client. */
  url: typeof MLBiasAPI.resolvePath;
};
39
+
40
/** True when a step value marks the step as finished (`true` or `'done'`). */
const isJobStepDone = (value: unknown): boolean => value === true || value === 'done';

/** Extracts a message from a thrown value, falling back to `String(value)`. */
const describeJobError = (e: unknown): string => {
  const message = (e as { message?: unknown } | null | undefined)?.message;
  return typeof message === 'string' && message ? message : String(e);
};

/**
 * React hook that starts a pipeline run through `MLBiasAPI` and tracks it.
 *
 * `start` posts the configuration with `MLBiasAPI.runPipeline`. When the
 * response already contains metrics the job is recorded as completed right
 * away. Otherwise the hook polls `MLBiasAPI.checkHealth` once per second until
 * the health payload (or the stored response) indicates that the run is done.
 *
 * Guarantees added on top of the basic flow:
 *   - calling `start` again first tears down the poller of the previous run,
 *     so intervals never accumulate;
 *   - responses that arrive after `cancel`, after unmount, or for a run that
 *     has been superseded by a newer `start` are ignored;
 *   - the polling callback reads the current pipeline response through a ref
 *     instead of a value captured by an old render;
 *   - a step only counts as finished when it is `true` or `'done'`
 *     (`'todo'` / `'doing'` no longer end the polling).
 *
 * @returns job state, live progress derived from the health payload, and the
 *          `start` / `cancel` controls.
 */
export function useJobRunner(): UseJobRunnerReturn {
  const [jobId, setJobId] = useState<string | null>(null);

  const [result, setResult] = useState<JobResult | null>(null);
  const [resp, setResp] = useState<PipelineResponseDTO | undefined>();

  const [loading, setLoading] = useState(false);
  const [error, setErr] = useState<string | undefined>();

  const [health, setHealth] = useState<HealthLike | null>(null);

  const pollRef = useRef<number | null>(null);
  const aliveRef = useRef<boolean>(false);
  // Mirror of `resp` that is always current, for use inside timer callbacks.
  const respRef = useRef<PipelineResponseDTO | undefined>(undefined);
  // Incremented by every start(); identifies which run a callback belongs to.
  const runSeqRef = useRef(0);

  const storeResp = (next: PipelineResponseDTO | undefined) => {
    respRef.current = next;
    setResp(next);
  };

  const stopPolling = () => {
    if (pollRef.current !== null) {
      window.clearInterval(pollRef.current);
      pollRef.current = null;
    }
    aliveRef.current = false;
  };

  const cancel = () => {
    stopPolling();
    setLoading(false);
  };

  const progressPercent = useMemo(() => {
    const s = (health?.steps as Record<string, boolean | string>) || {};
    const keys = Object.keys(s);
    if (keys.length === 0) return result?.progress ?? 0;
    let score = 0;
    keys.forEach((k) => {
      const v = s[k];
      if (isJobStepDone(v)) score += 1;
      else if (v === 'doing') score += 0.5; // a running step counts as half done
    });
    return Math.max(0, Math.min(100, Math.round((score / keys.length) * 100)));
  }, [health?.steps, result?.progress]);

  const liveSteps: Record<string, boolean> = useMemo(() => {
    const fromResp = ((resp?.results as any)?.steps || {}) as Record<string, boolean>;
    const fromHealth = (health?.steps || {}) as Record<string, boolean | string>;
    const normalized: Record<string, boolean> = {};
    Object.keys(fromResp).forEach((k) => {
      normalized[k] = !!fromResp[k];
    });
    // Health values win over response values; a running step is flagged too.
    Object.keys(fromHealth).forEach((k) => {
      const v = fromHealth[k];
      normalized[k] = isJobStepDone(v) || v === 'doing';
    });
    return normalized;
  }, [health?.steps, resp?.results]);

  /** Performs one health poll on behalf of run `seq`. */
  const pollOnce = async (seq: number) => {
    // Stale when polling was stopped or a newer start() took over.
    const isStale = () => !aliveRef.current || runSeqRef.current !== seq;
    try {
      const h = (await MLBiasAPI.checkHealth()) as HealthLike;
      if (isStale()) return;

      // Keep the previous object when nothing changed to avoid re-renders.
      setHealth((prev) => (JSON.stringify(prev) === JSON.stringify(h) ? prev : h));

      const steps = (h?.steps || {}) as Record<string, boolean | string>;
      const latestResults = respRef.current?.results as any;

      const plotsDone =
        isJobStepDone(steps['6_plots_and_metrics']) ||
        latestResults?.plots_ready ||
        (latestResults?.plot_urls?.length ?? 0) > 0;

      const samplingDone =
        isJobStepDone(steps['4_rank_sampling_original']) &&
        isJobStepDone(steps['5_rank_sampling_cf']);

      const genAvailable = !!h?.generation_results_available;
      const ftMaybeDone =
        isJobStepDone(steps['7_finetune']) ||
        latestResults?.finetune_done ||
        latestResults?.finetune?.completed;

      const declaredCompleted = h?.completed === true || h?.status === 'completed';

      if (declaredCompleted || plotsDone || samplingDone || (genAvailable && ftMaybeDone)) {
        stopPolling();
        setLoading(false);
      }
    } catch (e: unknown) {
      // A failed poll is reported but does not stop polling.
      if (!isStale()) setErr(describeJobError(e));
    }
  };

  const start = async (config: JobConfig) => {
    // Tear down any poller left over from a previous run before starting.
    stopPolling();
    const seq = ++runSeqRef.current;
    const isSuperseded = () => runSeqRef.current !== seq;

    setLoading(true);
    setErr(undefined);

    const now = new Date().toISOString();
    const provisionalId = crypto.randomUUID();
    setResult({
      id: provisionalId,
      status: 'running',
      progress: 0,
      config,
      createdAt: now,
      updatedAt: now,
    });
    storeResp(undefined);
    setHealth(null);

    try {
      const runResp: any = await MLBiasAPI.runPipeline(config);
      // A newer start() owns the state now; drop this response.
      if (isSuperseded()) return;

      const jid: string | undefined =
        runResp?.jobId || runResp?.job_id || runResp?.results?.jobId || runResp?.results?.job_id;
      setJobId(jid || provisionalId);

      // The pipeline answered with final metrics: no polling needed.
      if (runResp?.results?.metrics) {
        const final = runResp as PipelineResponseDTO;
        const now2 = new Date().toISOString();

        storeResp(final);
        setResult({
          id: jid || provisionalId,
          status: 'completed',
          progress: 100,
          config,
          createdAt: now,
          updatedAt: now2,
          completedAt: now2,
          metrics: {
            finalMeanDiff: final.results.metrics.finalMeanDiff,
            reductionPct: final.results.metrics.reductionPct ?? 0,
            stableCoverage: final.results.metrics.stableCoverage ?? 100,
          },
        });
        setLoading(false);
        return;
      }

      aliveRef.current = true;
      await pollOnce(seq);
      // The first poll may already have finished or been cancelled.
      if (aliveRef.current && !isSuperseded()) {
        pollRef.current = window.setInterval(() => {
          void pollOnce(seq);
        }, 1000);
      }
    } catch (e: unknown) {
      if (isSuperseded()) return;
      setErr(describeJobError(e));
      setResult((prev) =>
        prev
          ? { ...prev, status: 'failed', progress: 100, updatedAt: new Date().toISOString() }
          : null
      );
      setLoading(false);
    }
  };

  // Stop polling when the component using this hook unmounts.
  useEffect(() => stopPolling, []);

  const url = MLBiasAPI.resolvePath;

  return {
    result,
    resp,
    loading,
    error,
    start,
    cancel,
    jobId,
    live: {
      health,
      steps: liveSteps,
      updatedAt: (health && (health.updated_at || health.timestamp)) || null,
      finetuneRunning: !!(health?.finetune_running || (resp as any)?.results?.finetune?.running),
      progressPercent,
    },
    url,
  };
}
frontend/src/index.css ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ @tailwind base;
2
+ @tailwind components;
3
+ @tailwind utilities;
4
+
frontend/src/main.tsx ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import React from 'react'
2
+ import ReactDOM from 'react-dom/client'
3
+ import App from './App.tsx'
4
+ import './index.css'
5
+
6
// Mount the application into the #root element, wrapped in React.StrictMode.
const rootElement = document.getElementById('root')!
const appRoot = ReactDOM.createRoot(rootElement)

appRoot.render(
  <React.StrictMode>
    <App />
  </React.StrictMode>,
)
frontend/src/pages/ConfigPage.tsx ADDED
@@ -0,0 +1,649 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useEffect, useState } from 'react';
2
+ import { Database, Bot, ExternalLink, Shuffle } from 'lucide-react';
3
+ import DatasetValidator from '../components/validators/DatasetValidator';
4
+ import ModelValidator from '../components/validators/ModelValidator';
5
+ import { DATASETS } from '../constants/datasets';
6
+ import { LM_MODELS} from '../constants/models';
7
+ import type { JobConfig } from '../types';
8
+
9
/** Run options passed to `onRun` alongside the job configuration. */
type Extras = {
  /** Value of the "Testing Data Limit" input. */
  datasetLimit: number;
};
13
type ClassificationTask = 'sentiment' | 'regard' | 'stereotype' | 'personality' | 'toxicity';
type ToxicityModelChoice = 'detoxify' | 'junglelee';

/** Prefix applied to relative API paths. */
const CONFIG_API_BASE = '/api';

/**
 * Fetches JSON from `url` (relative paths are prefixed with the API base).
 * Throws an `Error` carrying the HTTP status when the response is not OK.
 */
async function fetchJSON<T>(url: string, signal?: AbortSignal): Promise<T> {
  const fullURL = url.startsWith('http') ? url : `${CONFIG_API_BASE}${url}`;
  const res = await fetch(fullURL, { signal });
  if (!res.ok) throw new Error(`${res.status} ${res.statusText}`);
  return (await res.json()) as T;
}

/** Builds the relative URL of the dataset-fields endpoint; blank config/split are omitted. */
function buildFieldsURL(datasetId: string, config: string | null, split: string): string {
  const params = new URLSearchParams();
  params.set('id', datasetId);

  if (config && config.trim() !== '') {
    params.set('config', config);
  }

  if (split && split.trim() !== '') {
    params.set('split', split);
  }

  return `/dataset/fields?${params.toString()}`;
}

/** Parses an integer form value; non-numeric input yields `fallback`, otherwise clamps to [min, max]. */
function parseClampedInt(raw: string, fallback: number, min: number, max: number): number {
  const parsed = parseInt(raw, 10);
  return Number.isFinite(parsed) ? Math.max(min, Math.min(max, parsed)) : fallback;
}

/** Parses a decimal form value; non-numeric input yields `fallback`, otherwise clamps to [min, max]. */
function parseClampedFloat(raw: string, fallback: number, min: number, max: number): number {
  const parsed = parseFloat(raw);
  return Number.isFinite(parsed) ? Math.max(min, Math.min(max, parsed)) : fallback;
}

/** Returns the `message` of a thrown value when it is a non-empty string. */
function readErrorMessage(err: unknown): string | undefined {
  const message = (err as { message?: unknown } | null | undefined)?.message;
  return typeof message === 'string' && message ? message : undefined;
}

/**
 * Configuration page: lets the user pick a dataset, counterfactual settings,
 * models and fine-tuning parameters, then hands the assembled configuration to
 * `onRun`.
 *
 * Data loading is split in two effects:
 *   - when the dataset changes, its meta (configs / splits) is fetched and the
 *     default config / split are selected;
 *   - once the meta for the current dataset is known, and whenever config or
 *     split changes, the field list and the field statistics are fetched.
 * Requests cancelled by an effect cleanup are ignored rather than reported as
 * errors.
 *
 * @param onRun Called with the full job configuration and the extra run
 *              options when the Start button is pressed.
 */
export default function ConfigPage({ onRun }: { onRun: (cfg: JobConfig, extras: Extras) => void }) {
  const [cfg, setCfg] = useState<JobConfig>({
    dataset: '',
    languageModel: '',
    scorerModel: '',
    k: 5,
    numCounterfactuals: 3,
    metrictarget: 0.5,
    tau: 0.1,
    iterations: 1000,
    seed: 42,
    enableFineTuning: false,
    counterfactual: false,
  });

  const [datasetLimit, setDatasetLimit] = useState<number>(10);
  const [customDataset, setCustomDataset] = useState('');
  const [customLM, setCustomLM] = useState('');
  const [showCustomDatasetInput, setShowCustomDatasetInput] = useState(false);
  const [showCustomLanguageInput, setShowCustomLanguageInput] = useState(false);
  const [fieldStats, setFieldStats] = useState<Record<string, Record<string, number>>>({});
  const [numCounterfactuals, setNumCounterfactuals] = useState<number>(3);
  const [classificationTask, setClassificationTask] = useState<ClassificationTask>('sentiment');
  const [toxicityModelChoice, setToxicityModelChoice] = useState<ToxicityModelChoice>('detoxify');
  const [selectedCfFields, setSelectedCfFields] = useState<string[]>([]);
  const [availableFields, setAvailableFields] = useState<string[]>([]);
  const [isLoadingFields, setIsLoadingFields] = useState(false);
  const [fieldsError, setFieldsError] = useState<string | null>(null);
  const [metaConfigs, setMetaConfigs] = useState<string[]>([]);
  const [metaSplits, setMetaSplits] = useState<string[]>([]);
  const [selectedConfig, setSelectedConfig] = useState<string | null>(null);
  const [selectedSplit, setSelectedSplit] = useState<string>('train');
  // Dataset whose meta lookup has settled; config/split are only trusted for it.
  const [metaDatasetId, setMetaDatasetId] = useState<string | null>(null);

  const canStart = !!(cfg.dataset && cfg.languageModel);
  const [ftEpochs, setFtEpochs] = useState(3);
  const [ftBatchSize, setFtBatchSize] = useState(8);
  const [ftLR, setFtLR] = useState(5e-5);
  const setField = <K extends keyof JobConfig>(k: K, v: JobConfig[K]) =>
    setCfg((prev) => ({ ...prev, [k]: v }));

  const card = 'group relative rounded-2xl p-8 border border-white/30 bg-white/60 backdrop-blur-xl ' +
    'shadow-[0_15px_40px_-20px_rgba(30,41,59,0.35)] transition-all duration-300 ' +
    'hover:shadow-[0_20px_50px_-20px_rgba(79,70,229,0.45)] hover:-translate-y-0.5';

  const sectionTitle = 'text-xl font-bold tracking-tight text-slate-900';
  const fieldInput = 'w-full rounded-xl border-2 border-slate-200/70 bg-white/70 px-4 py-3 ' +
    'focus:outline-none focus:border-indigo-500 focus:ring-4 focus:ring-indigo-500/20 transition-all';
  const selectInput = 'w-full rounded-xl border-2 border-slate-200/70 bg-white/70 px-3 py-2.5 ' +
    'focus:outline-none focus:border-indigo-500 focus:ring-4 focus:ring-indigo-500/20 transition-all';
  const choiceRow = 'flex items-start gap-4 cursor-pointer p-4 rounded-xl border transition-colors ' +
    'bg-white/60 hover:bg-white/80 border-slate-200/60 hover:border-indigo-300';

  // Effect 1: the dataset changed -> reset dependent state and load its meta.
  useEffect(() => {
    console.log('📊 Dataset changed:', cfg.dataset);
    setSelectedCfFields([]);
    setFieldsError(null);
    setAvailableFields([]);
    setFieldStats({}); // statistics of the previous dataset must not linger

    if (!cfg.dataset || cfg.dataset === 'custom') return;

    const datasetId = cfg.dataset;
    const ac = new AbortController();

    const run = async () => {
      const metaURL = `/dataset/meta?id=${encodeURIComponent(datasetId)}`;
      try {
        console.log('🔍 Fetching dataset meta...');
        const meta = await fetchJSON<{
          datasetId: string;
          configs: string[];
          splits: string[];
        }>(metaURL, ac.signal);

        console.log('📋 Meta data received:', meta);

        setMetaConfigs(meta.configs || []);
        setMetaSplits(meta.splits || []);

        const defaultConfig = meta.configs?.length ? meta.configs[0] : null;
        const defaultSplit = meta.splits?.length
          ? (meta.splits.includes('train') ? 'train' : meta.splits[0])
          : 'train';

        setSelectedConfig(defaultConfig);
        setSelectedSplit(defaultSplit);
      } catch (err: unknown) {
        // Cancelled by the cleanup below: a newer request owns the state.
        if (ac.signal.aborted) return;
        console.error('❌ Error in dataset effect:', err);

        setMetaConfigs([]);
        setMetaSplits([]);
        setSelectedConfig(null);
        setSelectedSplit('train');

        setAvailableFields([]);
        setFieldsError(`(${metaURL}) → ${readErrorMessage(err) || 'Field Read Failed'}`);
      } finally {
        // Set last so the fields effect only runs once config/split are in place.
        if (!ac.signal.aborted) setMetaDatasetId(datasetId);
      }
    };

    void run();
    return () => ac.abort();
  }, [cfg.dataset]);

  // Effect 2: load fields and field statistics for the current dataset/config/split.
  useEffect(() => {
    if (!cfg.dataset || cfg.dataset === 'custom') return;
    // Config/split still belong to another dataset until its meta has settled.
    if (metaDatasetId !== cfg.dataset) return;

    console.log('🔄 Config/Split changed - config:', selectedConfig, 'split:', selectedSplit);

    const datasetId = cfg.dataset;
    const ac = new AbortController();

    const run = async () => {
      const fieldsURL = buildFieldsURL(datasetId, selectedConfig, selectedSplit);
      setIsLoadingFields(true);
      try {
        const fieldsData = await fetchJSON<{ fields: string[] }>(fieldsURL, ac.signal);

        setAvailableFields(fieldsData.fields || []);
        setFieldsError(null);
        setSelectedCfFields([]);

        const statsURL = `/dataset/field-stats?id=${encodeURIComponent(datasetId)}&field=domain&subfield=category`;
        const statsData = await fetchJSON<{ counts: Record<string, Record<string, number>> }>(statsURL, ac.signal);
        setFieldStats(statsData.counts || {});
      } catch (err: unknown) {
        if (ac.signal.aborted) return;
        console.error('❌ Error fetching fields after config/split change:', err);
        setAvailableFields([]);
        setFieldsError(`(${fieldsURL}) → ${readErrorMessage(err) || 'Field Read Failed'}`);
      } finally {
        // Never clear the flag on behalf of a request that was cancelled.
        if (!ac.signal.aborted) setIsLoadingFields(false);
      }
    };

    void run();
    return () => {
      ac.abort();
      // The cancelled request will not clear the flag itself.
      setIsLoadingFields(false);
    };
  }, [cfg.dataset, selectedConfig, selectedSplit, metaDatasetId]);

  return (
    <div className="space-y-10">
      <div className="grid grid-cols-1 lg:grid-cols-6 gap-8">
        {/* Dataset selection */}
        <div className={`${card} lg:col-span-3`}>
          <div className="flex items-center gap-3 mb-8">
            <div className="p-3 rounded-xl bg-gradient-to-br from-indigo-600 to-fuchsia-600 shadow-md shadow-indigo-600/30">
              <Database className="w-6 h-6 text-white" />
            </div>
            <h3 className={sectionTitle}>Dataset Selection</h3>
          </div>

          <div className="space-y-4">
            {DATASETS.map((dataset) => (
              <label key={dataset.id} className={choiceRow}>
                <input
                  type="radio"
                  name="dataset"
                  value={dataset.id}
                  checked={!showCustomDatasetInput && cfg.dataset === dataset.id}
                  onChange={(e) => {
                    setField('dataset', e.target.value);
                    setShowCustomDatasetInput(false);
                    setCustomDataset('');
                    setSelectedCfFields([]);
                  }}
                  className="mt-1 accent-indigo-600"
                />
                <div className="flex-1">
                  <div className="font-semibold text-slate-900">{dataset.name}</div>
                  <div className="flex items-center gap-4 text-xs text-slate-500 mt-2">
                    {'entities' in dataset && (
                      <span>📊 {(dataset as any).entities?.toLocaleString?.() || '-'} entities</span>
                    )}
                    {'groups' in dataset && <span>👥 {(dataset as any).groups || '-'} groups</span>}
                  </div>
                  <a
                    href={`https://huggingface.co/datasets/${dataset.id}`}
                    target="_blank"
                    rel="noopener noreferrer"
                    className="inline-flex items-center gap-1 text-indigo-600 hover:text-indigo-700 text-xs font-medium mt-2"
                    onClick={(e) => e.stopPropagation()}
                  >
                    <ExternalLink className="w-3.5 h-3.5" />
                    View on Hugging Face
                  </a>
                </div>
              </label>
            ))}

            {/* Custom dataset */}
            <label className={choiceRow}>
              <input
                type="radio"
                name="dataset"
                value="custom"
                checked={showCustomDatasetInput}
                onChange={() => {
                  // Use the typed ID (possibly empty), never the literal "custom".
                  setField('dataset', customDataset);
                  setShowCustomDatasetInput(true);
                  setSelectedCfFields([]);
                }}
                className="mt-1 accent-fuchsia-600"
              />
              <div className="flex-1">
                <div className="font-semibold text-slate-900">🔧 Custom Dataset Upload from Hugging Face</div>
              </div>
            </label>

            {showCustomDatasetInput && (
              <div className="pl-6 space-y-3 animate-in slide-in-from-top duration-300">
                <input
                  type="text"
                  placeholder="Input Hugging Face Dataset ID (e.g. AmazonScience/bold)"
                  value={customDataset}
                  onChange={(e) => {
                    setCustomDataset(e.target.value);
                    setField('dataset', e.target.value);
                  }}
                  className={fieldInput}
                />
                {customDataset && customDataset.includes('/') && (
                  <DatasetValidator datasetId={customDataset} />
                )}
              </div>
            )}

            {cfg.dataset === 'AmazonScience/bold' && !showCustomDatasetInput && (
              <DatasetValidator datasetId="AmazonScience/bold" />
            )}
          </div>
        </div>

        {/* Counterfactual analysis (middle column) */}
        <div className={`${card} lg:col-span-3`}>
          <div className="flex items-center gap-3 mb-8">
            <div className="p-3 rounded-xl bg-gradient-to-br from-pink-600 to-rose-600 shadow-md shadow-pink-600/30">
              <Shuffle className="w-6 h-6 text-white" />
            </div>
            <h3 className={sectionTitle}>Counterfactual Setting</h3>
          </div>

          <div className="space-y-6">

            <div className="pt-2">
              <label className="block text-sm font-semibold text-slate-800 mb-1">
                Number of Counterfactual
              </label>
              <input
                type="number"
                min={1}
                max={20}
                step={1}
                value={numCounterfactuals}
                onChange={(e) => setNumCounterfactuals(parseClampedInt(e.target.value, 3, 1, 20))}
                className={fieldInput}
              />
            </div>

            {/* Dataset meta (dropdowns shown only when configs/splits exist) */}
            {(metaConfigs.length > 0 || metaSplits.length > 0) && (
              <div className="grid grid-cols-1 sm:grid-cols-2 gap-4">
                {metaConfigs.length > 0 && (
                  <div>
                    <label className="block text-sm font-semibold text-slate-800 mb-1">Dataset Config</label>
                    <select
                      value={selectedConfig || ''}
                      onChange={(e) => setSelectedConfig(e.target.value || null)}
                      className={selectInput}
                    >
                      {metaConfigs.map((c) => (
                        <option key={c} value={c}>{c}</option>
                      ))}
                    </select>
                  </div>
                )}

                {metaSplits.length > 0 && (
                  <div>
                    <label className="block text-sm font-semibold text-slate-800 mb-1">Split</label>
                    <select
                      value={selectedSplit}
                      onChange={(e) => setSelectedSplit(e.target.value)}
                      className={selectInput}
                    >
                      {metaSplits.map((s) => (
                        <option key={s} value={s}>{s}</option>
                      ))}
                    </select>
                  </div>
                )}
              </div>
            )}

            {/* Status line */}
            <div className="text-xs text-slate-500 flex items-center gap-2">
              <span>Selected Dataset</span>
              <span className="inline-flex items-center rounded-full bg-slate-800/90 text-white px-2.5 py-1">
                {cfg.dataset || 'Not Selected Yet'}
              </span>
              {selectedConfig && <span className="ml-1">/ {selectedConfig}</span>}
              {selectedSplit && <span className="ml-1">/ {selectedSplit}</span>}
            </div>

            {/* Field list */}
            <div>
              <div className="flex items-center justify-between mb-2">
                <div className="text-sm font-semibold text-slate-800">Optional fields</div>
                {isLoadingFields && <span className="text-xs text-slate-500">Loading</span>}
              </div>

              {/* Surface load failures instead of silently showing an empty list */}
              {fieldsError && (
                <div role="alert" className="mb-2 text-xs text-red-600 break-words">
                  {fieldsError}
                </div>
              )}

              <div className="space-y-4 max-h-64 overflow-auto pr-1">
                {Object.entries(fieldStats).map(([domain, categories]) => (
                  <div key={domain} className="bg-white/50 border border-slate-200 rounded-xl p-3 shadow-sm">
                    <div className="font-semibold text-slate-700 text-sm mb-2">{domain}</div>
                    <div className="grid grid-cols-1 sm:grid-cols-2 gap-x-4 gap-y-2 pl-1">
                      {Object.entries(categories).map(([category, count]) => {
                        const fieldKey = `${domain}/${category}`;
                        return (
                          <label
                            key={fieldKey}
                            className="flex items-center gap-2 text-sm text-slate-800 hover:bg-white/60 px-2 py-1 rounded-md transition-colors"
                          >
                            <input
                              type="checkbox"
                              checked={selectedCfFields.includes(fieldKey)}
                              onChange={() =>
                                setSelectedCfFields((prev) =>
                                  prev.includes(fieldKey)
                                    ? prev.filter((x) => x !== fieldKey)
                                    : [...prev, fieldKey]
                                )
                              }
                              className="accent-fuchsia-600"
                            />
                            <span>{category}</span>
                            <span className="text-xs text-slate-500">({count})</span>
                          </label>
                        );
                      })}
                    </div>
                  </div>
                ))}
              </div>
            </div>
          </div>
        </div>

        {/* Model selection (also hosts K, the data limit and the metric target) */}
        <div className={`${card} lg:col-span-3`}>
          <div className="flex items-center gap-3 mb-8">
            <div className="p-3 rounded-xl bg-gradient-to-br from-emerald-600 to-teal-600 shadow-md shadow-emerald-600/30">
              <Bot className="w-6 h-6 text-white" />
            </div>
            <h3 className={sectionTitle}>Model Selection</h3>
          </div>

          <div className="space-y-8">
            {/* Language model */}
            <div>
              <label className="block text-sm font-semibold text-slate-800 mb-2">🤖 Language Generation Model</label>
              <select
                // In custom mode cfg.languageModel holds the typed ID, which is
                // not an <option>; keep the "custom" entry selected instead.
                value={showCustomLanguageInput ? 'custom' : cfg.languageModel}
                onChange={(e) => {
                  const choice = e.target.value;
                  const isCustom = choice === 'custom';
                  setShowCustomLanguageInput(isCustom);
                  setField('languageModel', isCustom ? customLM : choice);
                }}
                className={selectInput}
              >
                <option value="">Select a Language Model</option>
                {LM_MODELS.map((m) => (
                  <option key={m.id} value={m.id}>
                    {m.name}({m.provider})
                  </option>
                ))}
                <option value="custom">🔧 Custom Model Upload from Hugging Face</option>
              </select>

              {showCustomLanguageInput && (
                <input
                  type="text"
                  placeholder="Input Hugging Face Model ID (e.g.:microsoft/DialoGPT-medium)"
                  value={customLM}
                  onChange={(e) => {
                    setCustomLM(e.target.value);
                    setField('languageModel', e.target.value);
                  }}
                  className={`${fieldInput} mt-3`}
                />
              )}

              {/* Always validate the model that will actually be submitted */}
              {cfg.languageModel && (
                <div className="mt-3">
                  <ModelValidator modelId={cfg.languageModel} type="language" />
                </div>
              )}

              {/* Below the language model: K and the dataset limit */}
              <div className="mt-6 space-y-5">
                <div>
                  <label className="block text-sm font-semibold text-slate-800 mb-1">
                    Number of Candidates
                    <span className="ml-2 text-xs font-normal text-slate-500">The number of candidates generated for each entity</span>
                  </label>
                  <input
                    type="number"
                    min={1}
                    max={20}
                    value={cfg.k}
                    onChange={(e) => setField('k', parseClampedInt(e.target.value, 1, 1, 20))}
                    className={fieldInput}
                  />
                </div>

                <div>
                  <label className="block text-sm font-semibold text-slate-800 mb-1">
                    Testing Data Limit
                  </label>
                  <input
                    type="number"
                    min={1}
                    max={10000}
                    value={datasetLimit}
                    onChange={(e) => setDatasetLimit(parseClampedInt(e.target.value, 1, 1, 10000))}
                    className={fieldInput}
                  />
                </div>
              </div>
            </div>

            {/* Classification task (fixed options) */}
            <div className="mt-6">
              <label className="block text-sm font-semibold text-slate-800 mb-1">
                👻 Feature Extraction Model
              </label>
              <select
                value={classificationTask}
                onChange={(e) => setClassificationTask(e.target.value as ClassificationTask)}
                className={selectInput}
              >
                <option value="sentiment">Sentiment (0–1, Neutral ≈ 0.5)</option>
                <option value="regard">Regard (0–2, Neutral ≈ 1.0)</option>
                <option value="stereotype">Stereotype (0–1, Neutral ≈ 0.0)</option>
                <option value="personality">Personality (0–1, Neutral ≈ 0.2)</option>
                <option value="toxicity">Toxicity (0–1, Neutral ≈ 0.0)</option>
              </select>
            </div>

            {/* Toxicity model choice (shown only for the toxicity task) */}
            {classificationTask === 'toxicity' && (
              <div className="mt-4">
                <label className="block text-sm font-semibold text-slate-800 mb-1">
                  Toxicity Model Selection
                </label>
                <select
                  value={toxicityModelChoice}
                  onChange={(e) => setToxicityModelChoice(e.target.value as ToxicityModelChoice)}
                  className={selectInput}
                >
                  <option value="detoxify">unitary/toxic-bert(detoxify)</option>
                  <option value="junglelee">JungleLee/bert-toxic-comment-classification</option>
                </select>
              </div>
            )}

            {/* Below the scoring model: metric target value */}
            <div className="mt-6">
              <label className="block text-sm font-semibold text-slate-800 mb-1">
                Metric Target Value
                <span className="ml-2 text-xs font-normal text-slate-500">Indicator thresholds used to determine compliance</span>
              </label>
              <input
                type="number"
                min={0}
                max={2}
                step={0.01}
                value={cfg.metrictarget}
                onChange={(e) => setField('metrictarget', parseClampedFloat(e.target.value, 0, 0, 2))}
                className={fieldInput}
              />
            </div>
          </div>
        </div>

        {/* Fine-tuning settings */}
        <div className={`${card} lg:col-span-3`}>
          <div className="flex items-center gap-3 mb-8">
            <div className="p-3 rounded-xl bg-gradient-to-br from-orange-500 to-yellow-500 shadow-md shadow-orange-500/30">
              <Database className="w-6 h-6 text-white" />
            </div>
            <h3 className={sectionTitle}>Fine-tuning Setting</h3>
          </div>

          <div className="space-y-6">
            <label className="flex items-center gap-2">
              <input
                type="checkbox"
                checked={cfg.enableFineTuning}
                onChange={(e) => setField('enableFineTuning', e.target.checked)}
                className="accent-orange-500"
              />
              <span className="text-sm text-slate-800 font-semibold">Enable Fine-tuning</span>
            </label>

            {cfg.enableFineTuning && (
              <div className="space-y-4 pl-4 border-l-2 border-orange-200">
                {/* Epochs */}
                <div>
                  <label className="block text-sm font-semibold text-slate-800 mb-1">
                    Training Epochs
                  </label>
                  <input
                    type="number"
                    min={1}
                    max={100}
                    value={ftEpochs}
                    onChange={(e) => setFtEpochs(parseClampedInt(e.target.value, 3, 1, 100))}
                    className={fieldInput}
                  />
                </div>

                {/* Batch Size */}
                <div>
                  <label className="block text-sm font-semibold text-slate-800 mb-1">
                    Batch Size
                  </label>
                  <input
                    type="number"
                    min={1}
                    max={256}
                    value={ftBatchSize}
                    onChange={(e) => setFtBatchSize(parseClampedInt(e.target.value, 8, 1, 256))}
                    className={fieldInput}
                  />
                </div>

                {/* Learning Rate */}
                <div>
                  <label className="block text-sm font-semibold text-slate-800 mb-1">
                    Learning Rate
                  </label>
                  <input
                    type="number"
                    step={0.00001}
                    value={ftLR}
                    onChange={(e) => {
                      const parsed = parseFloat(e.target.value);
                      setFtLR(Number.isFinite(parsed) ? parsed : 0.00005);
                    }}
                    className={fieldInput}
                  />
                </div>
              </div>
            )}
          </div>
        </div>

      </div>

      {/* Start button */}
      <div className="flex">
        <button
          onClick={() => {
            const fullCfg = {
              ...cfg,
              selectedCfFields,
              numCounterfactuals,
              classificationTask,
              toxicityModelChoice,
              finetuneParams: {
                epochs: ftEpochs,
                batchSize: ftBatchSize,
                learningRate: ftLR,
              },
            };
            onRun(fullCfg, {
              datasetLimit
            });
          }}
          disabled={!canStart}
          className="relative w-full group overflow-hidden rounded-2xl px-6 py-4 text-white font-semibold bg-gradient-to-r from-indigo-600 via-violet-600 to-fuchsia-600 shadow-lg shadow-indigo-600/20 enabled:hover:shadow-indigo-600/40 transition-all enabled:hover:translate-y-[-1px] enabled:active:translate-y-0 disabled:opacity-60 disabled:cursor-not-allowed"
        >
          <span className="relative z-10">🚀 Start</span>
          <span className="absolute inset-0 opacity-0 group-hover:opacity-100 transition-opacity bg-[radial-gradient(1200px_200px_at_50%_-40%,rgba(255,255,255,0.35),transparent_60%)]" />
        </button>
      </div>

    </div>
  );
}
frontend/src/pages/ResultsPage.tsx ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PipelineProgress from '../components/PipelineProgress';
2
+ import { useJobRunner } from '../hooks/JobRunnerProvider';
3
+
4
/** One entry of the "Download Report" list. */
type DownloadLink = { label: string; href: string };

/** Returns `value` when it is a non-empty string, otherwise `undefined`. */
function nonEmptyString(value: unknown): string | undefined {
  return typeof value === 'string' && value !== '' ? value : undefined;
}

/**
 * Formats a metric with four decimals.
 * Falls back to a dash when the backend did not deliver a number, so a
 * missing metric no longer crashes the whole page.
 */
function formatMetric(value: unknown): string {
  return typeof value === 'number' ? value.toFixed(4) : '—';
}

/**
 * Builds the list of downloadable artefacts from the raw pipeline results.
 * The payload is narrowed manually instead of being cast to `any`, because
 * several of the fields read here are optional extras of the response.
 */
function collectDownloadLinks(results: unknown): DownloadLink[] {
  if (typeof results !== 'object' || results === null) return [];
  const fields = results as Record<string, unknown>;
  const links: DownloadLink[] = [];

  const add = (label: string, candidate: unknown): boolean => {
    const href = nonEmptyString(candidate);
    if (href === undefined) return false;
    links.push({ label, href });
    return true;
  };

  add('Generation CSV', fields.generation_file);
  add('Original sentiment subset CSV', fields.sentiment_subset_file);
  add('CF sentiment subset CSV', fields.cf_sentiment_subset_file);

  const runConfig = fields.run_config_files;
  if (typeof runConfig === 'object' && runConfig !== null) {
    const files = runConfig as Record<string, unknown>;
    add('Run Config (Markdown)', files.markdown);
    add('Run Config (JSON)', files.json);
  }

  // The ZIP archive wins; the folder link is only offered when no ZIP exists.
  if (!add('Fine-tuned Model (ZIP)', fields.finetuned_model_zip)) {
    add('Fine-tuned Model Folder', fields.finetuned_model_dir);
  }

  return links;
}

/**
 * Results page: shows the pipeline metrics, the two distribution plots and
 * the download links of the most recent run provided by `useJobRunner`.
 *
 * Rendering order of the states:
 * 1. loading without any response yet -> progress bar plus skeleton cards
 * 2. error -> error box
 * 3. no result / response -> "Task not executed yet"
 * 4. otherwise -> the full report
 */
export default function ResultsPage() {
  const { result, resp, loading, error, url } = useJobRunner();

  if (loading && !resp) {
    return (
      <div className="space-y-6">
        <PipelineProgress />
        <section className="grid grid-cols-1 md:grid-cols-2 gap-6">
          {[0, 1].map((i) => (
            <div key={i} className="rounded-2xl border bg-white/70 backdrop-blur p-3">
              <div className="w-full h-64 rounded-xl bg-gradient-to-b from-slate-200 to-slate-100 animate-pulse" />
              <div className="mt-3 h-4 w-40 rounded bg-slate-200 animate-pulse" />
            </div>
          ))}
        </section>
      </div>
    );
  }

  if (error) {
    return (
      <div className="p-6 rounded-2xl bg-red-50 border border-red-200 text-red-700">
        {error}
      </div>
    );
  }

  if (!result || !resp) {
    return (
      <div className="p-6 rounded-2xl bg-white/70 border border-white/40">
        Task not executed yet
      </div>
    );
  }

  const plots = resp.results.plots;
  const originalSrc = url(plots.original_sentiment);
  const cfSrc = url(plots.counterfactual_sentiment);

  // `metrics` may be absent on either object, so the values are read
  // defensively instead of through a non-null assertion.
  const originalDiff = formatMetric(result.metrics?.finalMeanDiff);
  const cfDiff = formatMetric(resp.results.metrics?.cfFinalMeanDiff);

  const links = collectDownloadLinks(resp.results);

  return (
    <div className="space-y-6">
      {loading && <PipelineProgress />}

      <section className="p-6 rounded-2xl border border-white/40 bg-white/70 backdrop-blur">
        <h2 className="text-lg font-semibold mb-3">Metric</h2>
        <div className="grid grid-cols-1 sm:grid-cols-2 gap-4">
          <div className="p-4 rounded-xl bg-slate-50 border">
            <div className="text-slate-500 text-sm">Original Difference</div>
            <div className="text-2xl font-bold">{originalDiff}</div>
          </div>
          <div className="p-4 rounded-xl bg-slate-50 border">
            <div className="text-slate-500 text-sm">CF Difference</div>
            <div className="text-2xl font-bold">{cfDiff}</div>
          </div>
        </div>
      </section>

      <section className="p-6 rounded-2xl border border-white/40 bg-white/70 backdrop-blur">
        <h2 className="text-lg font-semibold mb-4">Distribution</h2>
        <div className="grid grid-cols-1 md:grid-cols-2 gap-6">
          <figure className="rounded-xl overflow-hidden border bg-white">
            <img
              src={originalSrc}
              alt="Original distribution"
              className="w-full h-auto"
              loading="lazy"
              onError={(e) => {
                e.currentTarget.alt = 'Original image loading failed';
              }}
            />
            <figcaption className="p-3 text-sm text-slate-600">Original</figcaption>
          </figure>

          <figure className="rounded-xl overflow-hidden border bg-white">
            <img
              src={cfSrc}
              alt="Counterfactual distribution"
              className="w-full h-auto"
              loading="lazy"
              onError={(e) => {
                e.currentTarget.alt = 'Counterfactual image loading failed';
              }}
            />
            <figcaption className="p-3 text-sm text-slate-600">Counterfactual Augmented</figcaption>
          </figure>
        </div>
      </section>

      {links.length > 0 && (
        <section className="p-6 rounded-2xl border border-white/40 bg-white/70 backdrop-blur">
          <h2 className="text-lg font-semibold mb-4">Download Report</h2>
          <ul className="space-y-2">
            {links.map((l) => (
              <li key={l.label}>
                <a
                  className="text-indigo-600 hover:underline"
                  href={url(l.href)}
                  target="_blank"
                  rel="noreferrer"
                  download
                >
                  {l.label}
                </a>
              </li>
            ))}
          </ul>
        </section>
      )}
    </div>
  );
}
frontend/src/services/api.ts ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // src/services/api.ts
2
+ import type { JobConfig } from '../types';
3
+
4
/**
 * Locations of the two plot images attached to a pipeline result.
 * ResultsPage resolves each value through `url(...)` and uses it as an
 * `<img>` source.
 */
export type PipelinePlots = {
  /** Plot shown under the "Original" caption. */
  original_sentiment: string;
  /** Plot shown under the "Counterfactual Augmented" caption. */
  counterfactual_sentiment: string;
};
8
+
9
/**
 * Payload found under `results` in a pipeline response.
 *
 * The backend also returns stereotype-related fields; the frontend does not
 * use them, so they are intentionally left undeclared.
 */
export type PipelineResultsDTO = {
  data_loaded: number;
  model_loaded: string;

  generation_file: string;
  generation_samples: number;

  counterfactual_file: string;
  counterfactual_added: number;
  counterfactual_total: number;

  sampling_method: string;
  sentiment_subset_file: string;
  sentiment_subset_size: number;

  cf_sentiment_subset_file: string;
  cf_sentiment_subset_size: number;

  /**
   * Optional artefacts that ResultsPage offers as downloads when present.
   * Declared here so consumers do not have to cast the results to `any`.
   */
  run_config_files?: {
    markdown?: string;
    json?: string;
  };
  finetuned_model_zip?: string;
  finetuned_model_dir?: string;

  config_used: import('../types').JobConfig;
  /** `JobMetrics` (which already carries `finalMeanDiff`) plus the counterfactual value. */
  metrics: import('../types').JobMetrics & {
    cfFinalMeanDiff: number;
  };
  plots: PipelinePlots;
};
35
+
36
/** Envelope returned by the pipeline endpoint. */
export type PipelineResponseDTO = {
  /** Outcome flag reported by the backend. */
  status: 'success' | 'error';
  message: string;
  timestamp: string;
  /** Detailed outputs of the run. */
  results: PipelineResultsDTO;
};
42
+
43
/**
 * Base URL of the backend API, taken from `VITE_API_BASE` with `/api` as the
 * fallback. Trailing slashes are stripped so that `${BASE}/pipeline` never
 * turns into `...//pipeline` when the variable is configured as e.g. `/api/`.
 */
const BASE = String(import.meta.env.VITE_API_BASE ?? '/api').replace(/\/+$/, '');
44
+
45
/**
 * Sends the job configuration to `${BASE}/pipeline` as a JSON POST body of
 * the form `{ config }` and resolves with the parsed JSON response.
 *
 * @param config configuration object forwarded to the backend as-is
 * @throws Error containing the HTTP status and response text when the
 *         backend answers with a non-2xx status
 */
async function runPipeline(config: any) {
  const endpoint = `${BASE}/pipeline`;
  const payload = JSON.stringify({ config });

  const response = await fetch(endpoint, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: payload,
  });

  if (response.ok) {
    return response.json();
  }

  const detail = await response.text();
  throw new Error(`Pipeline failed (${response.status}): ${detail}`);
}
57
+
58
/**
 * Queries `${BASE}/health` and resolves with the parsed JSON body.
 *
 * @throws Error containing the HTTP status and response text when the
 *         backend answers with a non-2xx status. This mirrors `runPipeline`
 *         and avoids an opaque JSON parse error when a proxy returns an
 *         HTML error page.
 */
async function checkHealth() {
  const response = await fetch(`${BASE}/health`);
  if (!response.ok) {
    const detail = await response.text();
    throw new Error(`Health check failed (${response.status}): ${detail}`);
  }
  return response.json();
}
62
+
63
/**
 * Turns a path reported by the backend into a URL the browser can load.
 *
 * - empty / missing input -> `''`
 * - absolute `http://` or `https://` URL -> returned unchanged
 * - anything else -> prefixed with `BASE` (a leading `/` is added if missing)
 *
 * Only a real scheme prefix counts as absolute; a relative path that merely
 * starts with the letters "http" (e.g. `http_plots/a.png`) is still routed
 * through `BASE`.
 */
function resolvePath(p?: string) {
  if (!p) return '';
  if (/^https?:\/\//i.test(p)) return p;
  const path = p.startsWith('/') ? p : `/${p}`;
  return `${BASE}${path}`;
}
69
+
70
+ export const MLBiasAPI = { runPipeline, checkHealth, resolvePath };
frontend/src/services/hf.ts ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // src/services/hf.ts
2
/**
 * Looks up a model on the Hugging Face Hub API and resolves with the parsed
 * JSON metadata.
 *
 * The id is trimmed and each `/`-separated segment is URL-encoded, so
 * characters such as `?` or `#` cannot alter the request. A blank id is
 * rejected up front: it would otherwise hit the model listing endpoint and
 * be mistaken for an existing model.
 *
 * @param modelId Hub identifier, e.g. `gpt2` or `org/name`
 * @throws Error when the id is blank or the Hub answers with a non-2xx status
 */
export async function fetchHFModel(modelId: string) {
  const id = modelId.trim();
  if (id === '') {
    throw new Error('The model does not exist or cannot be accessed');
  }
  const encodedId = id
    .split('/')
    .map((segment) => encodeURIComponent(segment))
    .join('/');
  const r = await fetch(`https://huggingface.co/api/models/${encodedId}`);
  if (!r.ok) throw new Error('The model does not exist or cannot be accessed');
  return r.json();
}
7
+
8
/**
 * Looks up a dataset on the Hugging Face Hub API and resolves with the
 * parsed JSON metadata.
 *
 * The id is trimmed and each `/`-separated segment is URL-encoded, so
 * characters such as `?` or `#` cannot alter the request. A blank id is
 * rejected up front: it would otherwise hit the dataset listing endpoint and
 * be mistaken for an existing dataset.
 *
 * @param datasetId Hub identifier, e.g. `imdb` or `org/name`
 * @throws Error when the id is blank or the Hub answers with a non-2xx status
 */
export async function fetchHFDataset(datasetId: string) {
  const id = datasetId.trim();
  if (id === '') {
    throw new Error('The dataset does not exist or cannot be accessed');
  }
  const encodedId = id
    .split('/')
    .map((segment) => encodeURIComponent(segment))
    .join('/');
  const r = await fetch(`https://huggingface.co/api/datasets/${encodedId}`);
  if (!r.ok) throw new Error('The dataset does not exist or cannot be accessed');
  return r.json();
}
frontend/src/types/index.ts ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // src/types/index.ts
2
/** Identifier / display-name pair describing a dataset. */
export type Dataset = {
  /** Identifier of the dataset. */
  id: string;
  /** Name of the dataset. */
  name: string;
};
6
+
7
/** Descriptor of a selectable model. */
export type Model = {
  id: string;
  name: string;
  /** Role of the model: a language model or a scorer. */
  type: 'language' | 'scorer';
  description: string;
  provider: string;
};
14
+
15
/**
 * Configuration of a single job. The same shape is stored on
 * `JobResult.config`.
 * NOTE(review): the exact meaning of `k`, `tau` and `metrictarget` is not
 * visible in this file — confirm against the backend before documenting.
 */
export type JobConfig = {
  dataset: string;
  languageModel: string;
  scorerModel: string;
  k: number;
  tau: number;
  iterations: number;
  seed: number;
  enableFineTuning: boolean;
  counterfactual: boolean;
  metrictarget: number;
  numCounterfactuals: number;
  /** Optional list of field names. */
  selectedCfFields?: string[];
};
30
+
31
/** Lifecycle state of a job. */
export type JobStatus =
  | 'running'
  | 'completed'
  | 'failed';
32
+
33
/** One data point of a per-iteration chart. */
export type ChartPoint = {
  /** Iteration this point belongs to. */
  iteration: number;
  meanDifference: number;
  groupA: number;
  groupB: number;
};
39
+
40
/** Summary metrics of a job. */
export type JobMetrics = {
  /** Rendered by ResultsPage under "Original Difference". */
  finalMeanDiff: number;
  reductionPct: number;
  stableCoverage: number;
};
45
+
46
/** State of a job as tracked by the frontend. */
export type JobResult = {
  id: string;
  status: JobStatus;
  /** Progress in percent, 0-100. */
  progress: number;
  config: JobConfig;
  createdAt: string;
  updatedAt: string;
  /** Optional completion timestamp. */
  completedAt?: string;
  /** Optional chart data. */
  charts?: ChartPoint[];
  /** Optional summary metrics; consumers must handle their absence. */
  metrics?: JobMetrics;
};
57
+
58
/** Additional run options handed over alongside the job configuration. */
export type Extras = {
  /** Limit applied to the dataset. NOTE(review): presumably a row count — confirm against the caller. */
  datasetLimit: number;
};
frontend/src/vite-env.d.ts ADDED
@@ -0,0 +1 @@
 
 
1
+ /// <reference types="vite/client" />
frontend/tailwind.config.js ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
// The config is an ES module (`export default`), so the forms plugin is
// pulled in with an ESM import rather than CommonJS `require`, which is not
// defined when the file is loaded as a native ES module.
import forms from '@tailwindcss/forms';

/** @type {import('tailwindcss').Config} */
export default {
  // Files scanned for class names.
  content: [
    './index.html',
    './src/**/*.{js,ts,jsx,tsx}',
  ],
  theme: {
    extend: {},
  },
  plugins: [forms],
};
frontend/tsconfig.app.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "compilerOptions": {
3
+ "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
4
+ "target": "ES2022",
5
+ "useDefineForClassFields": true,
6
+ "lib": ["ES2022", "DOM", "DOM.Iterable"],
7
+ "module": "ESNext",
8
+ "skipLibCheck": true,
9
+
10
+ /* Bundler mode */
11
+ "moduleResolution": "bundler",
12
+ "allowImportingTsExtensions": true,
13
+ "verbatimModuleSyntax": true,
14
+ "moduleDetection": "force",
15
+ "noEmit": true,
16
+ "jsx": "react-jsx",
17
+
18
+ /* Linting */
19
+ "strict": true,
20
+ "noUnusedLocals": false,
21
+ "noUnusedParameters": false,
22
+ "erasableSyntaxOnly": true,
23
+ "noFallthroughCasesInSwitch": true,
24
+ "noUncheckedSideEffectImports": true
25
+ },
26
+ "include": ["src"]
27
+ }
frontend/tsconfig.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "files": [],
3
+ "references": [
4
+ { "path": "./tsconfig.app.json" },
5
+ { "path": "./tsconfig.node.json" }
6
+ ]
7
+ }
8
+
frontend/tsconfig.node.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "compilerOptions": {
3
+ "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
4
+ "target": "ES2023",
5
+ "lib": ["ES2023"],
6
+ "module": "ESNext",
7
+ "skipLibCheck": true,
8
+
9
+ /* Bundler mode */
10
+ "moduleResolution": "bundler",
11
+ "allowImportingTsExtensions": true,
12
+ "verbatimModuleSyntax": true,
13
+ "moduleDetection": "force",
14
+ "noEmit": true,
15
+
16
+ /* Linting */
17
+ "strict": true,
18
+ "noUnusedLocals": true,
19
+ "noUnusedParameters": true,
20
+ "erasableSyntaxOnly": true,
21
+ "noFallthroughCasesInSwitch": true,
22
+ "noUncheckedSideEffectImports": true
23
+ },
24
+ "include": ["vite.config.ts"]
25
+ }
frontend/vite.config.ts ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import { defineConfig } from 'vite'
2
+ import react from '@vitejs/plugin-react'
3
+
4
+ // https://vite.dev/config/
5
+ export default defineConfig({
6
+ plugins: [react()],
7
+ })
nginx.conf.template ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ worker_processes auto;
2
+ events {
3
+ worker_connections 1024;
4
+ }
5
+
6
+ http {
7
+ include /etc/nginx/mime.types;
8
+ default_type application/octet-stream;
9
+ sendfile on;
10
+ keepalive_timeout 65;
11
+
12
+ server {
13
+ listen 7860;
14
+ server_name _;
15
+ root /usr/share/nginx/html;
16
+ index index.html;
17
+
18
+ # 前端單頁應用
19
+ location / {
20
+ try_files $uri $uri/ /index.html;
21
+ }
22
+
23
+ # 後端 API 反向代理
24
+ location /api/ {
25
+ proxy_pass http://127.0.0.1:5001/;
26
+ proxy_http_version 1.1;
27
+ proxy_set_header Upgrade $http_upgrade;
28
+ proxy_set_header Connection "upgrade";
29
+ proxy_set_header Host $host;
30
+ proxy_set_header X-Real-IP $remote_addr;
31
+ proxy_buffering off;
32
+
33
+ proxy_connect_timeout 3000s;
34
+ proxy_send_timeout 3000s;
35
+ proxy_read_timeout 3000s;
36
+ proxy_redirect off;
37
+
38
+ }
39
+ }
40
+ }