Update main.py
main.py (CHANGED)
@@ -8,9 +8,10 @@ from huggingface_hub import HfApi, hf_hub_download
 from datetime import datetime, timezone
 import logging
 import uvicorn # To run the app
+import random # <-- Added import for random choice
 
-tool_threshold = 3
-step_threshold = 5
+tool_threshold = 3
+step_threshold = 5
 
 # --- Configuration ---
 HF_DATASET_ID = "agents-course/unit4-students-scores"
@@ -28,54 +29,76 @@ def load_questions():
     global questions_for_api
     global ground_truth_answers
     tempo_filtered=[]
+    # Clear existing data to prevent duplication if called multiple times
+    questions_for_api.clear()
+    ground_truth_answers.clear()
+
+    logger.info("Starting to load and filter GAIA dataset...")
+    try:
+        dataset=load_dataset("gaia-benchmark/GAIA","2023_level1",trust_remote_code=True)
+        logger.info("GAIA dataset loaded.")
+    except Exception as e:
+        logger.error(f"Failed to load GAIA dataset: {e}", exc_info=True)
+        # Decide how to handle this: maybe raise the error or exit
+        raise RuntimeError("Could not load the primary GAIA dataset.") from e
+
     for question in dataset['validation']:
         metadata = question.get('Annotator Metadata') # Use .get() for safety
+
         if metadata: # Check if 'Annotator Metadata' exists
             num_tools_str = metadata.get('Number of tools')
             num_steps_str = metadata.get('Number of steps')
+
             # Check if both numbers exist before trying to convert
             if num_tools_str is not None and num_steps_str is not None:
                 try:
                     # Convert values to integers for comparison
                     num_tools = int(num_tools_str)
                     num_steps = int(num_steps_str)
+
                     # Apply the filter conditions
                     if num_tools < tool_threshold and num_steps < step_threshold:
+                        # logger.debug(f"MATCH FOUND (Task ID: {question.get('task_id', 'N/A')}) - Tools: {num_tools}, Steps: {num_steps}")
+                        # logger.debug(question) # Print the matching question dictionary
+                        # logger.debug("------------------------------------------------------------------")
                         tempo_filtered.append(question) # Add to the filtered list
                     # else: # Optional: Handle items that don't match the filter
+                    #     logger.debug(f"Skipping Task ID: {question.get('task_id', 'N/A')} - Tools: {num_tools}, Steps: {num_steps}")
                 except ValueError:
                     # Handle cases where 'Number of tools' or 'Number of steps' is not a valid integer
+                    logger.warning(f"Skipping Task ID: {question.get('task_id', 'N/A')} - Could not convert tool/step count to integer: tools='{num_tools_str}', steps='{num_steps_str}'.")
+                    # logger.debug("------------------------------------------------------------------")
+        else:
+            logger.warning(f"Skipping Task ID: {question.get('task_id', 'N/A')} - Missing 'Annotator Metadata'.")
+            # logger.debug("------------------------------------------------------------------")
+
     filtered_dataset=tempo_filtered
+    logger.info(f"Found {len(filtered_dataset)} questions matching the criteria (tools < {tool_threshold}, steps < {step_threshold}).")
+    # print(filtered_dataset) # Keep this commented unless debugging
+
+    processed_count = 0
     for item in filtered_dataset:
         task_id = item.get('task_id')
         question_text = item.get('Question')
         final_answer = item.get('Final answer')
+
         if task_id and question_text and final_answer is not None:
             questions_for_api.append({
                 "task_id": str(task_id), # Ensure ID is string
                 "question": question_text
             })
             ground_truth_answers[str(task_id)] = str(final_answer) # Ensure answer is string
+            processed_count += 1
         else:
-            logger.warning(f"Skipping item due to missing fields: {item}")
+            logger.warning(f"Skipping item due to missing fields (task_id, Question, or Final answer): {item}")
+
+    logger.info(f"Successfully processed and loaded {processed_count} questions for the API.")
-    print(questions_for_api)
+    # print(questions_for_api) # Keep this commented unless debugging
+
     if not questions_for_api:
+        logger.error("CRITICAL: No valid questions loaded after filtering. API endpoints needing questions will fail.")
+        # Depending on requirements, you might want to exit or raise an error here
+        # raise RuntimeError("Failed to load mandatory question data after filtering.")
 
 # --- Pydantic Models for Data Validation ---
 class Question(BaseModel):
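A minimal standalone sketch of how the threshold filter above treats a single record; the metadata values here are invented for illustration, not taken from GAIA:

tool_threshold = 3
step_threshold = 5

sample_metadata = {"Number of tools": "2", "Number of steps": "7"}  # hypothetical values

num_tools = int(sample_metadata["Number of tools"])
num_steps = int(sample_metadata["Number of steps"])
keep = num_tools < tool_threshold and num_steps < step_threshold
print(keep)  # False: 7 steps is not below the step threshold of 5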
@@ -115,13 +138,20 @@ async def startup_event():
     Loads the questions when the FastAPI application starts.
     """
     logger.info("Application startup: Loading questions...")
+    try:
+        load_questions() # Call your loading function here
+        if not questions_for_api:
+            logger.error("CRITICAL: No questions were loaded during startup. The /questions and /random-question endpoints might fail.")
+            # Depending on requirements, you might want the app to fail startup
+            # raise RuntimeError("Failed to load mandatory question data.")
+        else:
+            logger.info(f"Successfully loaded {len(questions_for_api)} questions.")
+    except Exception as e:
+        logger.error(f"CRITICAL ERROR DURING STARTUP while loading questions: {e}", exc_info=True)
+        # Decide if the app should exit if loading fails
+        # import sys
+        # sys.exit(1)
+
 
 # --- Helper Function to interact with HF Dataset ---
 def update_huggingface_dataset(username: str, score: float):
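For context, the startup-hook pattern in isolation. This is a minimal sketch of the same idea, not the app's actual wiring:

from fastapi import FastAPI

app = FastAPI()
loaded_items = []  # hypothetical module-level state

@app.on_event("startup")
async def startup_event():
    # Populate module-level state once, before the first request is served
    loaded_items.append("example")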
@@ -129,24 +159,21 @@ def update_huggingface_dataset(username: str, score: float):
     try:
         # 1. Load the dataset
         logger.info(f"Loading dataset '{HF_DATASET_ID}'...")
+        ds_dict = None
         try:
             # Use hf_hub_download to check if the parquet file exists, avoiding full dataset load error if empty
             # This assumes the dataset uses the default 'train' split and parquet format. Adjust if needed.
             hf_hub_download(repo_id=HF_DATASET_ID, filename="data/train-00000-of-00001.parquet", repo_type="dataset")
+            ds_dict = load_dataset(HF_DATASET_ID)
             logger.info("Dataset loaded successfully.")
-            if "train" not in ds:
+            if "train" not in ds_dict:
                 logger.warning(f"Dataset '{HF_DATASET_ID}' does not contain a 'train' split. Creating one.")
-                # Create an empty DataFrame with the correct schema if 'train' split is missing
                 df = pd.DataFrame({'username': pd.Series(dtype='str'),
                                    'score': pd.Series(dtype='float'),
                                    'timestamp': pd.Series(dtype='str')})
-                ds = DatasetDict({'train': Dataset.from_pandas(df)})
             else:
                 # Convert the 'train' split to a pandas DataFrame for easier manipulation
+                df = ds_dict['train'].to_pandas()
 
         except Exception as load_error: # Catch broad exception for file not found or other loading issues
             logger.warning(f"Could not load dataset '{HF_DATASET_ID}' or it might be empty/new ({load_error}). Creating structure.")
@@ -165,6 +192,8 @@ def update_huggingface_dataset(username: str, score: float):
 
         # Convert score column to numeric, coercing errors
         df['score'] = pd.to_numeric(df['score'], errors='coerce')
+        # Fill potential NaN values in score with 0.0 before comparison/aggregation
+        df['score'] = df['score'].fillna(0.0)
 
 
         # 2. Find existing score for the user
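For reference, a quick illustration of the coerce-then-fill step added here; the series values are invented:

import pandas as pd

s = pd.Series(["88.5", "not-a-number", None])
print(pd.to_numeric(s, errors="coerce").fillna(0.0).tolist())  # [88.5, 0.0, 0.0]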
@@ -174,9 +203,9 @@ def update_huggingface_dataset(username: str, score: float):
 
         if not existing_entries.empty:
             # User exists, find their highest score
-            # Handle potential NaN scores from coercion or previous bad data
+            # Handle potential NaN scores from coercion or previous bad data (though fillna above should help)
             max_existing_score = existing_entries['score'].max()
+            if score > max_existing_score:
                 logger.info(f"New score {score} is higher than existing max {max_existing_score} for {username}. Updating.")
                 # Remove old entries for this user
                 df = df[df['username'] != username]
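A small self-contained sketch of the "keep only the user's best score" update this hunk implements; the rows and the new score are made-up values:

import pandas as pd

df = pd.DataFrame({"username": ["alice", "alice", "bob"], "score": [40.0, 55.0, 90.0]})
username, new_score = "alice", 60.0

existing = df[df["username"] == username]
if existing.empty or new_score > existing["score"].max():
    df = df[df["username"] != username]  # drop the old entries for this user
    df = pd.concat([df, pd.DataFrame([{"username": username, "score": new_score}])], ignore_index=True)
print(df)  # bob keeps 90.0, alice is replaced by 60.0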
@@ -199,10 +228,16 @@ def update_huggingface_dataset(username: str, score: float):
             # Convert potentially modified DataFrame back to a Dataset object
             # Ensure the schema matches if columns were added/modified.
             # Use 'train' split convention.
+            # Make sure the dtypes are correct before creating the Dataset
+            df['username'] = df['username'].astype(str)
+            df['score'] = df['score'].astype(float)
+            df['timestamp'] = df['timestamp'].astype(str)
+
             updated_ds = DatasetDict({'train': Dataset.from_pandas(df)})
-            #updated_ds.push_to_hub(HF_DATASET_ID) # Token should be picked up from env or login
+            logger.info(f"Dataset to push: {updated_ds}") # Log the dataset structure
+            # updated_ds.push_to_hub(HF_DATASET_ID) # Token should be picked up from env or login
+            logger.warning("Dataset push to hub is currently commented out. Uncomment the line above to enable leaderboard updates.") # REMINDER
+            logger.info("Dataset push simulated/attempted.")
             return True
         else:
             return False # No update was pushed
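A standalone sketch of the DataFrame-to-Dataset round trip used before pushing; the repo id below is a placeholder, and the push stays commented out, as in the code above, since it requires a Hub write token:

import pandas as pd
from datasets import Dataset, DatasetDict

df = pd.DataFrame({
    "username": ["example-user"],
    "score": [77.78],
    "timestamp": ["2024-01-01T00:00:00+00:00"],
})
ds = DatasetDict({"train": Dataset.from_pandas(df)})
# ds.push_to_hub("your-org/your-scores-dataset")  # placeholder repo id; needs HF authentication
print(ds)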
@@ -217,17 +252,41 @@ def update_huggingface_dataset(username: str, score: float):
 
 @app.get("/questions",
          response_model=List[Question],
-         summary="Get Filtered Questions",
+         summary="Get All Filtered Questions",
+         description="Returns the complete list of questions (task_id and question text only) filtered based on criteria.")
 async def get_questions():
     """
     Provides the list of questions that agents should answer.
     """
-    print(questions_for_api)
+    # print(f"Returning {len(questions_for_api)} questions.") # Debug log
     if not questions_for_api:
+        logger.error("GET /questions requested but no questions are loaded.")
         raise HTTPException(status_code=404, detail="No questions available.")
     return questions_for_api
 
+# --- NEW ENDPOINT ---
+@app.get("/random-question",
+         response_model=Question,
+         summary="Get One Random Question",
+         description="Returns a single random question from the available filtered set.",
+         responses={
+             200: {"description": "A random question."},
+             404: {"model": ErrorResponse, "description": "No questions available to choose from."}
+         })
+async def get_random_question():
+    """
+    Provides a single, randomly selected question from the loaded list.
+    """
+    if not questions_for_api:
+        logger.warning("GET /random-question requested but no questions are loaded.")
+        raise HTTPException(status_code=404, detail="No questions available to choose from.")
+
+    # Select and return a random question dictionary
+    random_question = random.choice(questions_for_api)
+    logger.info(f"Returning random question with task_id: {random_question['task_id']}")
+    return random_question
+# --- END NEW ENDPOINT ---
+
 
 @app.post("/submit",
           response_model=ScoreResponse,
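A hedged client-side example for the new /random-question endpoint; the base URL is a placeholder for wherever this Space is actually served:

import requests

BASE_URL = "http://127.0.0.1:8000"  # placeholder; point this at the deployed Space

resp = requests.get(f"{BASE_URL}/random-question", timeout=10)
resp.raise_for_status()
question = resp.json()
print(question["task_id"], question["question"][:80])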
@@ -236,7 +295,7 @@ async def get_questions():
           responses={
               200: {"description": "Submission successful, score calculated."},
               400: {"model": ErrorResponse, "description": "Invalid input data."},
-              404: {"model": ErrorResponse, "description": "Task ID not found."},
+              404: {"model": ErrorResponse, "description": "Task ID not found in submission or ground truth."},
               500: {"model": ErrorResponse, "description": "Server error (e.g., failed to update dataset)."}
           })
 async def submit_answers(submission: Submission = Body(...)):
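A possible /submit request, sketched from what the diff shows; the exact field names of the answer items come from the Submission and answer models, which are not visible here, so "submitted_answer" is only a guess:

import requests

payload = {
    "username": "example-user",  # hypothetical user
    "answers": [
        {"task_id": "task-001", "submitted_answer": "42"},  # field name assumed, not confirmed by the diff
    ],
}
resp = requests.post("http://127.0.0.1:8000/submit", json=payload, timeout=30)
print(resp.status_code, resp.json())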
@@ -260,7 +319,8 @@ async def submit_answers(submission: Submission = Body(...)):
 
 
     correct_count = 0
+    total_attempted_in_payload = len(submission.answers)
+    valid_attempted_count = 0 # Count attempts where task_id was valid
     processed_ids = set()
 
     for answer_item in submission.answers:
@@ -270,23 +330,20 @@ async def submit_answers(submission: Submission = Body(...)):
         # Prevent duplicate task_id submissions in the same request
         if task_id in processed_ids:
             logger.warning(f"Duplicate task_id '{task_id}' in submission from {submission.username}. Skipping.")
-            continue
+            continue # Don't count this as an attempt for scoring
         processed_ids.add(task_id)
 
 
-        # Check if task_id is valid
+        # Check if task_id is valid (exists in our loaded ground truth)
         if task_id not in ground_truth_answers:
-            logger.warning(f"Task ID '{task_id}' submitted by {submission.username} not found in ground truth list.")
-            # raise HTTPException(status_code=404, detail=f"Task ID '{task_id}' not found.")
-            # Option 2: Skip this answer and continue scoring others (chosen here)
-            total_attempted -= 1 # Don't count this attempt if the ID was invalid
+            logger.warning(f"Task ID '{task_id}' submitted by {submission.username} not found in ground truth list. Skipping this answer.")
+            # Don't count this as a valid attempt for score calculation
             continue
 
+        # If we reach here, the task_id is valid
+        valid_attempted_count += 1
         ground_truth = ground_truth_answers[task_id]
+        # Compare answers (case-insensitive, strip whitespace)
         if submitted.strip().lower() == ground_truth.strip().lower():
             correct_count += 1
             logger.debug(f"Correct answer for {task_id} from {submission.username}")
@@ -294,15 +351,17 @@ async def submit_answers(submission: Submission = Body(...)):
             logger.debug(f"Incorrect answer for {task_id} from {submission.username}. Submitted: '{submitted}', Expected: '{ground_truth}'")
 
 
-    # Calculate score
+    # Calculate score based on valid attempts
+    if valid_attempted_count == 0:
         score = 0.0
+        message = f"Submission received, but no valid/matching task IDs were found in the {total_attempted_in_payload} answers provided."
-        logger.warning(f"No valid answers processed for {submission.username}.")
+        logger.warning(f"No valid answers processed for {submission.username} out of {total_attempted_in_payload} submitted.")
     else:
+        score = round((correct_count / valid_attempted_count) * 100, 2)
+        message = f"Score calculated successfully: {correct_count}/{valid_attempted_count} correct answers for valid tasks."
+        if valid_attempted_count < total_attempted_in_payload:
+            message += f" ({total_attempted_in_payload - valid_attempted_count} submitted answers had invalid or duplicate task IDs)."
+        logger.info(f"Score for {submission.username}: {score}% ({correct_count}/{valid_attempted_count})")
 
 
     # Update Hugging Face dataset
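Worked example of the score formula introduced in this hunk, with made-up counts:

correct_count = 7
valid_attempted_count = 9
score = round((correct_count / valid_attempted_count) * 100, 2)
print(score)  # 77.78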
@@ -328,21 +387,29 @@ async def submit_answers(submission: Submission = Body(...)):
         username=submission.username,
         score=score,
         correct_count=correct_count,
+        # Return the count of *valid* attempts for clarity
+        total_attempted=valid_attempted_count,
         message=message,
         timestamp=datetime.now(timezone.utc).isoformat()
     )
+
 # --- Run the application ---
 # This part is mainly for local development without Docker.
 # Docker uses the CMD instruction in the Dockerfile.
 if __name__ == "__main__":
     logger.info("Starting FastAPI server for local development...")
-    local_port = int(os.getenv("PORT", "8000"))
-    logger.info(f"Running Uvicorn locally on port: {local_port}")
-    # Note: host='127.0.0.1' is usually fine for local runs outside docker
+    # Explicitly call load_questions here for local run,
+    # as the @app.on_event("startup") might not trigger reliably
+    # depending on how uvicorn is invoked directly.
+    try:
         load_questions()
+        if not questions_for_api:
+            logger.error("EXITING: Cannot start server without loaded questions.")
+        else:
+            # Read port from environment variable for consistency, default to 8000 for local if not set
+            local_port = int(os.getenv("PORT", "8000"))
+            logger.info(f"Running Uvicorn locally on http://127.0.0.1:{local_port}")
+            # Note: host='127.0.0.1' is usually fine for local runs outside docker
+            uvicorn.run(app, host="127.0.0.1", port=local_port, log_level="info")
+    except Exception as e:
+        logger.error(f"Failed to start server: {e}", exc_info=True)