from flask import Blueprint, render_template, request, redirect, url_for, session, flash, jsonify, send_file
from app.models.models import Token, Submission, Settings, TrainingExample, FineTuningRun, SubmissionSentence
from app import db
from app.analyzer import get_analyzer
from app.utils.pdf_export import DashboardPDFExporter
from functools import wraps
from typing import Dict
import json
import csv
import io
from datetime import datetime
import os
import logging
logger = logging.getLogger(__name__)
bp = Blueprint('admin', __name__, url_prefix='/admin')
CONTRIBUTOR_TYPES = [
{'value': 'government', 'label': 'Government Officer', 'description': 'Public sector representatives'},
{'value': 'community', 'label': 'Community Member', 'description': 'Local residents and community leaders'},
{'value': 'industry', 'label': 'Industry Representative', 'description': 'Business and industry stakeholders'},
{'value': 'ngo', 'label': 'NGO/Non-Profit', 'description': 'Civil society organizations'},
{'value': 'academic', 'label': 'Academic/Researcher', 'description': 'Universities and research institutions'},
{'value': 'other', 'label': 'Other Stakeholder', 'description': 'Other interested parties'}
]
CATEGORIES = ['Vision', 'Problem', 'Objectives', 'Directives', 'Values', 'Actions']
def admin_required(f):
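    """Decorator that restricts a view to sessions authenticated with an admin token."""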
@wraps(f)
def decorated_function(*args, **kwargs):
if 'token' not in session or session.get('type') != 'admin':
return redirect(url_for('auth.login'))
return f(*args, **kwargs)
return decorated_function
@bp.route('/overview')
@admin_required
def overview():
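    """Render the admin overview page with submission/token counts and system toggles."""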
total_submissions = Submission.query.count()
total_tokens = Token.query.filter(Token.type != 'admin').count()
flagged_count = Submission.query.filter_by(flagged_as_offensive=True).count()
unanalyzed_count = Submission.query.filter_by(category=None).count()
submission_open = Settings.get_setting('submission_open', 'true') == 'true'
token_generation_enabled = Settings.get_setting('token_generation_enabled', 'true') == 'true'
analyzed = Submission.query.filter(Submission.category != None).count() > 0
return render_template('admin/overview.html',
total_submissions=total_submissions,
total_tokens=total_tokens,
flagged_count=flagged_count,
unanalyzed_count=unanalyzed_count,
submission_open=submission_open,
token_generation_enabled=token_generation_enabled,
analyzed=analyzed)
@bp.route('/registration')
@admin_required
def registration():
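    """Render the registration page with the self-service link and recently issued tokens."""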
token_generation_enabled = Settings.get_setting('token_generation_enabled', 'true') == 'true'
recent_tokens = Token.query.filter(Token.type != 'admin').order_by(Token.created_at.desc()).limit(10).all()
registration_url = request.host_url.rstrip('/') + url_for('auth.generate')
return render_template('admin/registration.html',
token_generation_enabled=token_generation_enabled,
recent_tokens=recent_tokens,
registration_url=registration_url)
@bp.route('/tokens')
@admin_required
def tokens():
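    """List all access tokens for management."""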
all_tokens = Token.query.all()
return render_template('admin/tokens.html',
tokens=all_tokens,
contributor_types=CONTRIBUTOR_TYPES)
@bp.route('/submissions')
@admin_required
def submissions():
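    """List submissions, optionally filtered by category or flagged status."""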
category_filter = request.args.get('category', 'all')
flagged_only = request.args.get('flagged', 'false') == 'true'
query = Submission.query
if category_filter != 'all':
query = query.filter_by(category=category_filter)
if flagged_only:
query = query.filter_by(flagged_as_offensive=True)
all_submissions = query.order_by(Submission.timestamp.desc()).all()
flagged_count = Submission.query.filter_by(flagged_as_offensive=True).count()
analyzed = Submission.query.filter(Submission.category != None).count() > 0
return render_template('admin/submissions.html',
submissions=all_submissions,
categories=CATEGORIES,
category_filter=category_filter,
flagged_only=flagged_only,
flagged_count=flagged_count,
analyzed=analyzed)
@bp.route('/dashboard')
@admin_required
def dashboard():
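    """Render the analysis dashboard in submission-level or sentence-level mode."""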
# Check if analyzed
analyzed = Submission.query.filter(Submission.category != None).count() > 0
if not analyzed:
flash('Please analyze submissions first', 'warning')
return redirect(url_for('admin.overview'))
# Get view mode from query param ('submissions' or 'sentences')
view_mode = request.args.get('mode', 'submissions')
# Contributor stats (unchanged - always submission-based)
contributor_stats = db.session.query(
Submission.contributor_type,
db.func.count(Submission.id)
).group_by(Submission.contributor_type).all()
# MODE DEPENDENT: Data changes based on sentence vs submission view
if view_mode == 'sentences':
# SENTENCE-LEVEL VIEW
# Get all sentences with categories joined with their parent submissions
sentences_query = db.session.query(SubmissionSentence, Submission).join(
Submission
).filter(
SubmissionSentence.category != None
).all()
# Create enhanced sentence objects with submission data
sentences = []
for sentence, submission in sentences_query:
# Create object with both sentence and submission attributes
class EnhancedSentence:
def __init__(self, sentence, submission):
self.id = sentence.id
self.text = sentence.text
self.message = sentence.text # For template compatibility
self.category = sentence.category
self.confidence = sentence.confidence
self.contributor_type = submission.contributor_type
self.timestamp = submission.timestamp
self.latitude = submission.latitude
self.longitude = submission.longitude
self.submission_id = submission.id
sentences.append(EnhancedSentence(sentence, submission))
# Category stats
category_stats = db.session.query(
SubmissionSentence.category,
db.func.count(SubmissionSentence.id)
).filter(SubmissionSentence.category != None).group_by(SubmissionSentence.category).all()
# Breakdown by contributor (via parent submission)
breakdown = {}
for cat in CATEGORIES:
breakdown[cat] = {}
for ctype in CONTRIBUTOR_TYPES:
count = db.session.query(db.func.count(SubmissionSentence.id)).join(
Submission
).filter(
SubmissionSentence.category == cat,
Submission.contributor_type == ctype['value']
).scalar()
breakdown[cat][ctype['value']] = count
# Geotagged sentences (inherit location from parent submission)
geotagged_items = db.session.query(SubmissionSentence, Submission).join(
Submission
).filter(
Submission.latitude != None,
Submission.longitude != None,
SubmissionSentence.category != None
).all()
# Create sentence objects with location data
geotagged_data = []
for sentence, submission in geotagged_items:
# Create a pseudo-object that has both sentence and location data
class SentenceWithLocation:
def __init__(self, sentence, submission):
self.id = sentence.id
self.text = sentence.text
self.category = sentence.category
self.latitude = submission.latitude
self.longitude = submission.longitude
self.contributor_type = submission.contributor_type
self.timestamp = submission.timestamp
self.message = sentence.text # For compatibility
geotagged_data.append(SentenceWithLocation(sentence, submission))
# Items for contributions list (sentences)
items_by_category = sentences
else:
# SUBMISSION-LEVEL VIEW (default)
# Get all submissions with categories
submissions = Submission.query.filter(Submission.category != None).all()
# Category stats
category_stats = db.session.query(
Submission.category,
db.func.count(Submission.id)
).filter(Submission.category != None).group_by(Submission.category).all()
# Breakdown by contributor type
breakdown = {}
for cat in CATEGORIES:
breakdown[cat] = {}
for ctype in CONTRIBUTOR_TYPES:
count = Submission.query.filter_by(
category=cat,
contributor_type=ctype['value']
).count()
breakdown[cat][ctype['value']] = count
# Geotagged submissions
geotagged_data = Submission.query.filter(
Submission.latitude != None,
Submission.longitude != None,
Submission.category != None
).all()
# Items for contributions list (submissions)
items_by_category = submissions
return render_template('admin/dashboard.html',
items=items_by_category,
contributor_stats=contributor_stats,
category_stats=category_stats,
geotagged_items=geotagged_data,
categories=CATEGORIES,
contributor_types=CONTRIBUTOR_TYPES,
breakdown=breakdown,
view_mode=view_mode)
@bp.route('/dashboard/export-pdf')
@admin_required
def export_dashboard_pdf():
"""Export dashboard data as PDF based on view mode"""
try:
# Get view mode
view_mode = request.args.get('mode', 'submissions')
# Contributor stats
contributor_stats = db.session.query(
Submission.contributor_type,
db.func.count(Submission.id)
).group_by(Submission.contributor_type).all()
# MODE DEPENDENT: Same logic as dashboard
if view_mode == 'sentences':
# SENTENCE-LEVEL VIEW
# Get all sentences with categories joined with their parent submissions
sentences_query = db.session.query(SubmissionSentence, Submission).join(
Submission
).filter(
SubmissionSentence.category != None
).all()
# Create enhanced sentence objects with submission data
sentences = []
for sentence, submission in sentences_query:
class EnhancedSentence:
def __init__(self, sentence, submission):
self.id = sentence.id
self.text = sentence.text
self.message = sentence.text # For template compatibility
self.category = sentence.category
self.confidence = sentence.confidence
self.contributor_type = submission.contributor_type
self.timestamp = submission.timestamp
self.latitude = submission.latitude
self.longitude = submission.longitude
self.submission_id = submission.id
sentences.append(EnhancedSentence(sentence, submission))
# Category stats
category_stats = db.session.query(
SubmissionSentence.category,
db.func.count(SubmissionSentence.id)
).filter(SubmissionSentence.category != None).group_by(SubmissionSentence.category).all()
# Breakdown by contributor
breakdown = {}
for cat in CATEGORIES:
breakdown[cat] = {}
for ctype in CONTRIBUTOR_TYPES:
count = db.session.query(db.func.count(SubmissionSentence.id)).join(
Submission
).filter(
SubmissionSentence.category == cat,
Submission.contributor_type == ctype['value']
).scalar()
breakdown[cat][ctype['value']] = count
# Geotagged sentences (inherit location from parent submission)
geotagged_items = db.session.query(SubmissionSentence, Submission).join(
Submission
).filter(
Submission.latitude != None,
Submission.longitude != None,
SubmissionSentence.category != None
).all()
# Create sentence objects with location data
geotagged_data = []
for sentence, submission in geotagged_items:
class SentenceWithLocation:
def __init__(self, sentence, submission):
self.id = sentence.id
self.text = sentence.text
self.category = sentence.category
self.latitude = submission.latitude
self.longitude = submission.longitude
self.contributor_type = submission.contributor_type
self.timestamp = submission.timestamp
self.message = sentence.text
geotagged_data.append(SentenceWithLocation(sentence, submission))
# Items for contributions list
items_list = sentences
else:
# SUBMISSION-LEVEL VIEW
# Get all submissions with categories
submissions = Submission.query.filter(Submission.category != None).all()
# Category stats
category_stats = db.session.query(
Submission.category,
db.func.count(Submission.id)
).filter(Submission.category != None).group_by(Submission.category).all()
# Breakdown by contributor
breakdown = {}
for cat in CATEGORIES:
breakdown[cat] = {}
for ctype in CONTRIBUTOR_TYPES:
count = Submission.query.filter_by(
category=cat,
contributor_type=ctype['value']
).count()
breakdown[cat][ctype['value']] = count
# Geotagged submissions
geotagged_data = Submission.query.filter(
Submission.latitude != None,
Submission.longitude != None,
Submission.category != None
).all()
# Items for contributions list
items_list = submissions
# Prepare data for PDF
pdf_data = {
'submissions': items_list, # Can be sentences or submissions
'category_stats': category_stats,
'contributor_stats': contributor_stats,
'breakdown': breakdown,
'geotagged_submissions': geotagged_data,
'view_mode': view_mode,
'categories': CATEGORIES,
'contributor_types': CONTRIBUTOR_TYPES
}
# Generate PDF
buffer = io.BytesIO()
exporter = DashboardPDFExporter()
exporter.generate_pdf(buffer, pdf_data)
buffer.seek(0)
# Generate filename
mode_label = "sentence" if view_mode == 'sentences' else "submission"
filename = f"dashboard_{mode_label}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pdf"
return send_file(
buffer,
mimetype='application/pdf',
as_attachment=True,
download_name=filename
)
except Exception as e:
logger.error(f"Error exporting dashboard PDF: {str(e)}")
flash(f'Error exporting PDF: {str(e)}', 'danger')
return redirect(url_for('admin.dashboard'))
# API Endpoints
@bp.route('/api/toggle-submissions', methods=['POST'])
@admin_required
def toggle_submissions():
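    """Toggle whether new submissions are accepted."""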
current = Settings.get_setting('submission_open', 'true')
new_value = 'false' if current == 'true' else 'true'
Settings.set_setting('submission_open', new_value)
return jsonify({'success': True, 'submission_open': new_value == 'true'})
@bp.route('/api/toggle-token-generation', methods=['POST'])
@admin_required
def toggle_token_generation():
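    """Toggle whether participants can self-generate access tokens."""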
current = Settings.get_setting('token_generation_enabled', 'true')
new_value = 'false' if current == 'true' else 'true'
Settings.set_setting('token_generation_enabled', new_value)
return jsonify({'success': True, 'token_generation_enabled': new_value == 'true'})
@bp.route('/api/create-token', methods=['POST'])
@admin_required
def create_token():
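    """Create a new access token for a contributor type (or an additional admin)."""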
data = request.json
contributor_type = data.get('type')
name = data.get('name', '').strip()
# Allow 'admin' type in addition to contributor types
valid_types = [t['value'] for t in CONTRIBUTOR_TYPES] + ['admin']
if not contributor_type or contributor_type not in valid_types:
return jsonify({'success': False, 'error': 'Invalid contributor type'}), 400
    import secrets
    import string
    prefix = contributor_type[:3].upper()
    # Use a cryptographically secure source for access-token characters rather than random
    random_part = ''.join(secrets.choice(string.ascii_uppercase + string.digits) for _ in range(6))
timestamp_part = str(int(datetime.now().timestamp()))[-4:]
token_str = f"{prefix}-{random_part}{timestamp_part}"
# Default name based on type
if contributor_type == 'admin':
final_name = name if name else "Administrator"
else:
final_name = name if name else f"{contributor_type.capitalize()} User"
new_token = Token(
token=token_str,
type=contributor_type,
name=final_name
)
db.session.add(new_token)
db.session.commit()
return jsonify({'success': True, 'token': new_token.to_dict()})
@bp.route('/api/delete-token/<int:token_id>', methods=['DELETE'])
@admin_required
def delete_token(token_id):
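    """Delete a non-admin access token."""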
token = Token.query.get_or_404(token_id)
# Prevent deletion of admin tokens (any token with type='admin')
if token.type == 'admin':
return jsonify({'success': False, 'error': 'Cannot delete admin token'}), 400
db.session.delete(token)
db.session.commit()
return jsonify({'success': True})
@bp.route('/api/update-category/<int:submission_id>', methods=['POST'])
@admin_required
def update_category(submission_id):
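    """Update a submission's category and record the change as a training example."""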
try:
submission = Submission.query.get_or_404(submission_id)
data = request.json
category = data.get('category')
confidence = data.get('confidence') # Optional: frontend can pass prediction confidence
# Store original category before change
original_category = submission.category
# Convert empty string to None
if category == '' or category == 'null':
category = None
# Validate category if not None
if category and category not in CATEGORIES:
return jsonify({'success': False, 'error': f'Invalid category: {category}'}), 400
# Create training example if admin is making a correction or confirmation
if category is not None: # Only track when assigning a category
# Check if training example already exists for this submission
existing_example = TrainingExample.query.filter_by(submission_id=submission_id).first()
if existing_example:
# Update existing example
existing_example.original_category = original_category
existing_example.corrected_category = category
existing_example.correction_timestamp = datetime.utcnow()
existing_example.confidence_score = confidence
else:
# Create new training example
training_example = TrainingExample(
submission_id=submission_id,
message=submission.message,
original_category=original_category,
corrected_category=category,
contributor_type=submission.contributor_type,
confidence_score=confidence
)
db.session.add(training_example)
# Update submission category
submission.category = category
db.session.commit()
return jsonify({'success': True, 'category': category})
except Exception as e:
db.session.rollback()
print(f"Error updating category: {str(e)}")
return jsonify({'success': False, 'error': str(e)}), 500
@bp.route('/api/update-sentence-category/<int:sentence_id>', methods=['POST'])
@admin_required
def update_sentence_category(sentence_id):
"""Update category for a specific sentence"""
try:
sentence = SubmissionSentence.query.get_or_404(sentence_id)
data = request.json
new_category = data.get('category')
# Store original
original_category = sentence.category
# Validate category
if new_category and new_category not in CATEGORIES:
return jsonify({'success': False, 'error': f'Invalid category: {new_category}'}), 400
# Update sentence
sentence.category = new_category
# Create/update training example for this sentence
if new_category:
existing = TrainingExample.query.filter_by(sentence_id=sentence_id).first()
if existing:
existing.original_category = original_category
existing.corrected_category = new_category
existing.correction_timestamp = datetime.utcnow()
else:
training_example = TrainingExample(
sentence_id=sentence_id,
submission_id=sentence.submission_id,
message=sentence.text, # Just the sentence text
original_category=original_category,
corrected_category=new_category,
contributor_type=sentence.submission.contributor_type
)
db.session.add(training_example)
# Update parent submission's primary category (recalculate from sentences)
submission = sentence.submission
submission.category = submission.get_primary_category()
db.session.commit()
return jsonify({'success': True, 'category': new_category})
except Exception as e:
db.session.rollback()
logger.error(f"Error updating sentence category: {str(e)}")
return jsonify({'success': False, 'error': str(e)}), 500
@bp.route('/api/toggle-flag/<int:submission_id>', methods=['POST'])
@admin_required
def toggle_flag(submission_id):
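    """Toggle the offensive-content flag on a submission."""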
submission = Submission.query.get_or_404(submission_id)
submission.flagged_as_offensive = not submission.flagged_as_offensive
db.session.commit()
return jsonify({'success': True, 'flagged': submission.flagged_as_offensive})
@bp.route('/api/delete-submission/<int:submission_id>', methods=['DELETE'])
@admin_required
def delete_submission(submission_id):
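    """Delete a submission."""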
submission = Submission.query.get_or_404(submission_id)
db.session.delete(submission)
db.session.commit()
return jsonify({'success': True})
@bp.route('/api/analyze', methods=['POST'])
@admin_required
def analyze_submissions():
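    """Analyze submissions with the active model (sentence- or submission-level), committing in batches and retrying on database locks."""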
import time
from sqlalchemy.exc import OperationalError
    data = request.get_json(silent=True) or {}
analyze_all = data.get('analyze_all', False)
use_sentences = data.get('use_sentences', True) # NEW: sentence-level flag (default: True)
# Get submissions to analyze
if analyze_all:
to_analyze = Submission.query.all()
else:
# For sentence-level, look for submissions without sentence analysis
if use_sentences:
to_analyze = Submission.query.filter_by(sentence_analysis_done=False).all()
else:
to_analyze = Submission.query.filter_by(category=None).all()
if not to_analyze:
return jsonify({'success': False, 'error': 'No submissions to analyze'}), 400
# Get the analyzer instance
analyzer = get_analyzer()
success_count = 0
error_count = 0
batch_size = 10 # Commit every 10 submissions to reduce lock time
for idx, submission in enumerate(to_analyze):
max_retries = 3
retry_delay = 1 # seconds
for attempt in range(max_retries):
try:
if use_sentences:
# NEW: Sentence-level analysis
sentence_results = analyzer.analyze_with_sentences(submission.message)
# Optimized DELETE: Use synchronize_session=False for better performance
SubmissionSentence.query.filter_by(submission_id=submission.id).delete(synchronize_session=False)
# Create new sentence records
for sent_idx, result in enumerate(sentence_results):
sentence = SubmissionSentence(
submission_id=submission.id,
sentence_index=sent_idx,
text=result['text'],
category=result['category'],
confidence=result.get('confidence')
)
db.session.add(sentence)
submission.sentence_analysis_done = True
# Set primary category for backward compatibility
submission.category = submission.get_primary_category()
logger.info(f"Analyzed submission {submission.id} into {len(sentence_results)} sentences")
else:
# OLD: Submission-level analysis (backward compatible)
category = analyzer.analyze(submission.message)
submission.category = category
success_count += 1
# Commit in batches to reduce lock duration
if (idx + 1) % batch_size == 0:
db.session.commit()
logger.info(f"Committed batch of {batch_size} submissions")
break # Success, exit retry loop
except OperationalError as e:
# Database locked error - retry with exponential backoff
if 'database is locked' in str(e) and attempt < max_retries - 1:
db.session.rollback()
wait_time = retry_delay * (2 ** attempt) # Exponential backoff
logger.warning(f"Database locked for submission {submission.id}, retrying in {wait_time}s (attempt {attempt + 1}/{max_retries})")
time.sleep(wait_time)
continue
else:
# Max retries reached or different error
db.session.rollback()
logger.error(f"Error analyzing submission {submission.id}: {e}")
error_count += 1
break
except Exception as e:
db.session.rollback()
logger.error(f"Error analyzing submission {submission.id}: {e}")
error_count += 1
break
# Final commit for remaining items
try:
db.session.commit()
logger.info(f"Final commit completed")
except Exception as e:
db.session.rollback()
logger.error(f"Error in final commit: {e}")
return jsonify({
'success': True,
'analyzed': success_count,
'errors': error_count,
'sentence_level': use_sentences
})
@bp.route('/export/json')
@admin_required
def export_json():
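    """Export tokens, submissions, training examples, and settings as a JSON download."""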
data = {
'tokens': [t.to_dict() for t in Token.query.all()],
'submissions': [s.to_dict() for s in Submission.query.all()],
'trainingExamples': [ex.to_dict() for ex in TrainingExample.query.all()],
'submissionOpen': Settings.get_setting('submission_open', 'true') == 'true',
'tokenGenerationEnabled': Settings.get_setting('token_generation_enabled', 'true') == 'true',
'exportDate': datetime.utcnow().isoformat()
}
json_str = json.dumps(data, indent=2)
buffer = io.BytesIO()
buffer.write(json_str.encode('utf-8'))
buffer.seek(0)
return send_file(
buffer,
mimetype='application/json',
as_attachment=True,
download_name=f'participatory-planning-{datetime.now().strftime("%Y-%m-%d")}.json'
)
@bp.route('/export/csv')
@admin_required
def export_csv():
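    """Export all submissions as a CSV download."""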
submissions = Submission.query.all()
output = io.StringIO()
writer = csv.writer(output)
# Header
writer.writerow(['Timestamp', 'Contributor Type', 'Category', 'Message', 'Latitude', 'Longitude', 'Flagged'])
# Rows
for s in submissions:
writer.writerow([
s.timestamp.isoformat() if s.timestamp else '',
s.contributor_type,
s.category or 'Not analyzed',
s.message,
s.latitude or '',
s.longitude or '',
'Yes' if s.flagged_as_offensive else 'No'
])
buffer = io.BytesIO()
buffer.write(output.getvalue().encode('utf-8'))
buffer.seek(0)
return send_file(
buffer,
mimetype='text/csv',
as_attachment=True,
download_name=f'contributions-{datetime.now().strftime("%Y-%m-%d")}.csv'
)
@bp.route('/import', methods=['POST'])
@admin_required
def import_data():
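    """Import a previously exported JSON session, replacing existing submissions and non-admin tokens."""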
if 'file' not in request.files:
return jsonify({'success': False, 'error': 'No file uploaded'}), 400
file = request.files['file']
if file.filename == '':
return jsonify({'success': False, 'error': 'No file selected'}), 400
try:
data = json.load(file)
# Clear existing data (except admin token)
Submission.query.delete()
Token.query.filter(Token.type != 'admin').delete()
# Import tokens
for token_data in data.get('tokens', []):
if token_data.get('type') != 'admin': # Skip admin token as it already exists
token = Token(
token=token_data['token'],
type=token_data['type'],
name=token_data['name']
)
db.session.add(token)
# Import submissions
for sub_data in data.get('submissions', []):
location = sub_data.get('location')
submission = Submission(
message=sub_data['message'],
contributor_type=sub_data['contributorType'],
latitude=location['lat'] if location else None,
longitude=location['lng'] if location else None,
timestamp=datetime.fromisoformat(sub_data['timestamp']) if sub_data.get('timestamp') else datetime.utcnow(),
category=sub_data.get('category'),
flagged_as_offensive=sub_data.get('flaggedAsOffensive', False)
)
db.session.add(submission)
# Import training examples if present
training_examples_imported = 0
for ex_data in data.get('trainingExamples', []):
            # Find the corresponding submission by message; skip examples without a match
submission = Submission.query.filter_by(message=ex_data['message']).first()
if submission:
training_example = TrainingExample(
submission_id=submission.id,
message=ex_data['message'],
original_category=ex_data.get('original_category'),
corrected_category=ex_data['corrected_category'],
contributor_type=ex_data['contributor_type'],
correction_timestamp=datetime.fromisoformat(ex_data['correction_timestamp']) if ex_data.get('correction_timestamp') else datetime.utcnow(),
confidence_score=ex_data.get('confidence_score'),
used_in_training=ex_data.get('used_in_training', False)
)
db.session.add(training_example)
training_examples_imported += 1
# Import settings
Settings.set_setting('submission_open', 'true' if data.get('submissionOpen', True) else 'false')
Settings.set_setting('token_generation_enabled', 'true' if data.get('tokenGenerationEnabled', True) else 'false')
db.session.commit()
return jsonify({
'success': True,
'training_examples_imported': training_examples_imported
})
except Exception as e:
db.session.rollback()
return jsonify({'success': False, 'error': str(e)}), 500
@bp.route('/api/clear-all-data', methods=['POST'])
@admin_required
def clear_all_data():
"""Clear all submissions and tokens (except admin)"""
try:
# Delete all submissions
Submission.query.delete()
# Delete all tokens except admin
Token.query.filter(Token.type != 'admin').delete()
# Optionally reset settings to defaults
Settings.set_setting('submission_open', 'true')
Settings.set_setting('token_generation_enabled', 'true')
db.session.commit()
return jsonify({'success': True, 'message': 'All data cleared successfully'})
except Exception as e:
db.session.rollback()
return jsonify({'success': False, 'error': str(e)}), 500
# ============================================================================
# FINE-TUNING & TRAINING DATA ENDPOINTS
# ============================================================================
@bp.route('/training')
@admin_required
def training_dashboard():
"""Display the fine-tuning training dashboard"""
# Get training statistics
total_examples = TrainingExample.query.count()
corrections_count = TrainingExample.query.filter(
TrainingExample.original_category != TrainingExample.corrected_category
).count()
confirmations_count = total_examples - corrections_count
# Category distribution
from sqlalchemy import func
category_distribution = db.session.query(
TrainingExample.corrected_category,
func.count(TrainingExample.id)
).group_by(TrainingExample.corrected_category).all()
category_stats = {cat: 0 for cat in CATEGORIES}
for cat, count in category_distribution:
if cat in category_stats:
category_stats[cat] = count
# Get all training runs
training_runs = FineTuningRun.query.order_by(FineTuningRun.created_at.desc()).all()
# Get active model
active_model = FineTuningRun.query.filter_by(is_active_model=True).first()
# Fine-tuning settings
min_training_examples = int(Settings.get_setting('min_training_examples', '20'))
fine_tuning_enabled = Settings.get_setting('fine_tuning_enabled', 'true') == 'true'
return render_template('admin/training.html',
total_examples=total_examples,
corrections_count=corrections_count,
confirmations_count=confirmations_count,
category_stats=category_stats,
categories=CATEGORIES,
training_runs=training_runs,
active_model=active_model,
min_training_examples=min_training_examples,
fine_tuning_enabled=fine_tuning_enabled,
ready_to_train=total_examples >= min_training_examples)
@bp.route('/api/training-stats', methods=['GET'])
@admin_required
def get_training_stats():
"""Get training data statistics (API endpoint)"""
total_examples = TrainingExample.query.count()
corrections_count = TrainingExample.query.filter(
TrainingExample.original_category != TrainingExample.corrected_category
).count()
# Category distribution
from sqlalchemy import func
category_distribution = db.session.query(
TrainingExample.corrected_category,
func.count(TrainingExample.id)
).group_by(TrainingExample.corrected_category).all()
category_stats = {cat: 0 for cat in CATEGORIES}
for cat, count in category_distribution:
if cat in category_stats:
category_stats[cat] = count
# Check for data quality issues
duplicates = db.session.query(
TrainingExample.message,
func.count(TrainingExample.id)
).group_by(TrainingExample.message).having(func.count(TrainingExample.id) > 1).count()
min_examples = int(Settings.get_setting('min_training_examples', '20'))
min_per_category = min(category_stats.values()) if category_stats.values() else 0
return jsonify({
'total_examples': total_examples,
'corrections_count': corrections_count,
'confirmations_count': total_examples - corrections_count,
'category_stats': category_stats,
'duplicates_count': duplicates,
'min_examples_threshold': min_examples,
'min_examples_per_category': min_per_category,
'ready_to_train': total_examples >= min_examples and min_per_category >= 2
})
@bp.route('/api/training-examples', methods=['GET'])
@admin_required
def get_training_examples():
"""Get all training examples"""
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 50, type=int)
category_filter = request.args.get('category', 'all')
corrections_only = request.args.get('corrections_only', 'false') == 'true'
query = TrainingExample.query
if category_filter != 'all':
query = query.filter_by(corrected_category=category_filter)
if corrections_only:
query = query.filter(TrainingExample.original_category != TrainingExample.corrected_category)
query = query.order_by(TrainingExample.correction_timestamp.desc())
pagination = query.paginate(page=page, per_page=per_page, error_out=False)
return jsonify({
'examples': [ex.to_dict() for ex in pagination.items],
'total': pagination.total,
'pages': pagination.pages,
'current_page': page
})
@bp.route('/api/training-example/<int:example_id>', methods=['DELETE'])
@admin_required
def delete_training_example(example_id):
"""Delete a training example"""
try:
example = TrainingExample.query.get_or_404(example_id)
# Don't allow deleting if already used in training
if example.used_in_training:
return jsonify({
'success': False,
'error': 'Cannot delete example already used in training run'
}), 400
db.session.delete(example)
db.session.commit()
return jsonify({'success': True})
except Exception as e:
db.session.rollback()
return jsonify({'success': False, 'error': str(e)}), 500
@bp.route('/api/export-training-examples', methods=['GET'])
@admin_required
def export_training_examples():
"""Export all training examples as JSON"""
try:
# Get filter parameters
sentence_level_only = request.args.get('sentence_level_only', 'false') == 'true'
# Query examples
query = TrainingExample.query
if sentence_level_only:
query = query.filter(TrainingExample.sentence_id != None)
examples = query.all()
# Export data
export_data = {
'exported_at': datetime.utcnow().isoformat(),
'total_examples': len(examples),
'sentence_level_only': sentence_level_only,
'examples': [
{
'message': ex.message,
'original_category': ex.original_category,
'corrected_category': ex.corrected_category,
'contributor_type': ex.contributor_type,
'correction_timestamp': ex.correction_timestamp.isoformat() if ex.correction_timestamp else None,
'confidence_score': ex.confidence_score,
'is_sentence_level': ex.sentence_id is not None
}
for ex in examples
]
}
# Return as downloadable JSON file
response = jsonify(export_data)
response.headers['Content-Disposition'] = f'attachment; filename=training_examples_{datetime.utcnow().strftime("%Y%m%d_%H%M%S")}.json'
response.headers['Content-Type'] = 'application/json'
return response
except Exception as e:
return jsonify({'success': False, 'error': str(e)}), 500
@bp.route('/api/import-training-examples', methods=['POST'])
@admin_required
def import_training_examples():
"""Import training examples from JSON file"""
try:
# Get JSON data from request
data = request.get_json()
if not data or 'examples' not in data:
return jsonify({
'success': False,
'error': 'Invalid import data. Expected JSON with "examples" array.'
}), 400
examples_data = data['examples']
imported_count = 0
skipped_count = 0
for ex_data in examples_data:
# Check if example already exists (by message and category)
existing = TrainingExample.query.filter_by(
message=ex_data['message'],
corrected_category=ex_data['corrected_category']
).first()
if existing:
skipped_count += 1
continue
# Create new training example
training_example = TrainingExample(
message=ex_data['message'],
original_category=ex_data.get('original_category'),
corrected_category=ex_data['corrected_category'],
contributor_type=ex_data.get('contributor_type', 'unknown'),
correction_timestamp=datetime.fromisoformat(ex_data['correction_timestamp']) if ex_data.get('correction_timestamp') else datetime.utcnow(),
confidence_score=ex_data.get('confidence_score'),
used_in_training=False
)
db.session.add(training_example)
imported_count += 1
db.session.commit()
return jsonify({
'success': True,
'imported': imported_count,
'skipped': skipped_count,
'total_in_file': len(examples_data)
})
except Exception as e:
db.session.rollback()
return jsonify({'success': False, 'error': str(e)}), 500
@bp.route('/api/clear-training-examples', methods=['POST'])
@admin_required
def clear_training_examples():
"""Clear all training examples (with options)"""
try:
data = request.get_json() or {}
# Options
clear_unused_only = data.get('unused_only', False)
sentence_level_only = data.get('sentence_level_only', False)
# Build query
query = TrainingExample.query
if clear_unused_only:
query = query.filter_by(used_in_training=False)
if sentence_level_only:
query = query.filter(TrainingExample.sentence_id != None)
# Count before delete
count = query.count()
# Delete
query.delete()
db.session.commit()
return jsonify({
'success': True,
'deleted': count,
'unused_only': clear_unused_only,
'sentence_level_only': sentence_level_only
})
except Exception as e:
db.session.rollback()
return jsonify({'success': False, 'error': str(e)}), 500
@bp.route('/import-training-dataset', methods=['POST'])
@admin_required
def import_training_dataset():
"""Import standalone training dataset (just training examples, not full session)"""
if 'file' not in request.files:
return jsonify({'success': False, 'error': 'No file uploaded'}), 400
file = request.files['file']
if file.filename == '':
return jsonify({'success': False, 'error': 'No file selected'}), 400
try:
data = json.load(file)
# Support both formats: array of examples or object with 'trainingExamples' key
training_data = data if isinstance(data, list) else data.get('trainingExamples', [])
imported_count = 0
for ex_data in training_data:
# Check if training example already exists (by message)
existing = TrainingExample.query.filter_by(message=ex_data['message']).first()
if existing:
# Update existing example
existing.original_category = ex_data.get('original_category')
existing.corrected_category = ex_data['corrected_category']
existing.contributor_type = ex_data.get('contributor_type', 'other')
existing.correction_timestamp = datetime.utcnow()
existing.confidence_score = ex_data.get('confidence_score')
else:
# Create placeholder submission if needed
submission = Submission.query.filter_by(message=ex_data['message']).first()
if not submission:
# Create placeholder submission for this training example
submission = Submission(
message=ex_data['message'],
contributor_type=ex_data.get('contributor_type', 'other'),
category=ex_data.get('corrected_category'),
timestamp=datetime.utcnow()
)
db.session.add(submission)
db.session.flush() # Get submission ID
# Create new training example
training_example = TrainingExample(
submission_id=submission.id,
message=ex_data['message'],
original_category=ex_data.get('original_category'),
corrected_category=ex_data['corrected_category'],
contributor_type=ex_data.get('contributor_type', 'other'),
confidence_score=ex_data.get('confidence_score')
)
db.session.add(training_example)
imported_count += 1
db.session.commit()
return jsonify({
'success': True,
'imported_count': imported_count
})
except KeyError as e:
db.session.rollback()
return jsonify({'success': False, 'error': f'Missing required field: {str(e)}'}), 400
except Exception as e:
db.session.rollback()
return jsonify({'success': False, 'error': str(e)}), 500
# ============================================================================
# FINE-TUNING TRAINING ORCHESTRATION ENDPOINTS
# ============================================================================
@bp.route('/api/start-fine-tuning', methods=['POST'])
@admin_required
def start_fine_tuning():
"""Start a fine-tuning training run"""
try:
config = request.json
# Validate minimum training examples
min_examples = int(Settings.get_setting('min_training_examples', '20'))
total_examples = TrainingExample.query.count()
if total_examples < min_examples:
return jsonify({
'success': False,
'error': f'Need at least {min_examples} training examples (have {total_examples})'
}), 400
# Create new training run record
training_run = FineTuningRun(
status='preparing'
)
training_run.set_config(config)
db.session.add(training_run)
db.session.commit()
run_id = training_run.id
# Start training in background thread
import threading
thread = threading.Thread(
target=_run_training_job,
args=(run_id, config)
)
thread.daemon = True
thread.start()
return jsonify({
'success': True,
'run_id': run_id,
'message': 'Training started'
})
except Exception as e:
db.session.rollback()
return jsonify({'success': False, 'error': str(e)}), 500
def _run_training_job(run_id: int, config: Dict):
"""Background job for training (runs in separate thread)"""
from app import create_app
from app.fine_tuning import BARTFineTuner
# Create new app context for this thread
app = create_app()
with app.app_context():
try:
# Get training run
run = FineTuningRun.query.get(run_id)
if not run:
print(f"Training run {run_id} not found")
return
# Update status
run.status = 'preparing'
db.session.commit()
# Get training examples (prefer sentence-level if available)
use_sentence_level = config.get('use_sentence_level_training', True)
if use_sentence_level:
# Use only sentence-level training examples
examples = TrainingExample.query.filter(TrainingExample.sentence_id != None).all()
# Fallback to submission-level if not enough sentence-level examples
if len(examples) < int(Settings.get_setting('min_training_examples', '20')):
logger.warning(f"Only {len(examples)} sentence-level examples found, including submission-level examples")
examples = TrainingExample.query.all()
else:
# Use all training examples (old behavior)
examples = TrainingExample.query.all()
training_data = [ex.to_dict() for ex in examples]
logger.info(f"Using {len(training_data)} training examples ({len([e for e in examples if e.sentence_id])} sentence-level)")
# Calculate split sizes
total = len(training_data)
run.num_training_examples = int(total * config.get('train_split', 0.7))
run.num_validation_examples = int(total * config.get('val_split', 0.15))
run.num_test_examples = total - run.num_training_examples - run.num_validation_examples
db.session.commit()
# Initialize trainer
trainer = BARTFineTuner()
# Prepare datasets
train_dataset, val_dataset, test_dataset = trainer.prepare_dataset(
training_data,
train_split=config.get('train_split', 0.7),
val_split=config.get('val_split', 0.15),
test_split=config.get('test_split', 0.15)
)
# Setup model based on training mode
training_mode = config.get('training_mode', 'head_only')
if training_mode == 'head_only':
# Head-only training (recommended for small datasets)
trainer.setup_head_only_model()
else:
# LoRA training
lora_config = {
'r': config.get('lora_rank', 16),
'lora_alpha': config.get('lora_alpha', 32),
'lora_dropout': config.get('lora_dropout', 0.1)
}
trainer.setup_lora_model(lora_config)
# Update status to training
run.status = 'training'
db.session.commit()
# Train
models_dir = os.getenv('MODELS_DIR', '/data/models/finetuned')
output_dir = os.path.join(models_dir, f'run_{run_id}')
training_config = {
'learning_rate': config.get('learning_rate', 3e-4),
'num_epochs': config.get('num_epochs', 3),
'batch_size': config.get('batch_size', 8)
}
train_metrics = trainer.train(
train_dataset,
val_dataset,
output_dir,
training_config,
run_id=run_id
)
# Update status to evaluating
run.status = 'evaluating'
run.model_path = output_dir
db.session.commit()
# Evaluate on test set
test_metrics = trainer.evaluate(test_dataset, output_dir)
# Combine metrics
results = {
**train_metrics,
**test_metrics
}
run.set_results(results)
# Calculate improvement over baseline (simplified - just use test accuracy)
baseline_accuracy = 0.60 # Placeholder - could run actual baseline comparison
            run.improvement_over_baseline = results.get('test_accuracy', 0) - baseline_accuracy
# Mark training examples as used
for example in examples:
example.used_in_training = True
example.training_run_id = run_id
# Complete
run.status = 'completed'
run.completed_at = datetime.utcnow()
db.session.commit()
print(f"Training run {run_id} completed successfully")
except Exception as e:
print(f"Training run {run_id} failed: {str(e)}")
run = FineTuningRun.query.get(run_id)
if run:
run.status = 'failed'
run.error_message = str(e)
db.session.commit()
@bp.route('/api/training-status/<int:run_id>', methods=['GET'])
@admin_required
def get_training_status(run_id):
"""Get status of a training run"""
run = FineTuningRun.query.get_or_404(run_id)
# Calculate progress percentage
progress = 0
if run.status == 'preparing':
progress = 10
elif run.status == 'training':
# Calculate precise progress based on steps
if run.total_steps and run.total_steps > 0 and run.current_step:
step_progress = (run.current_step / run.total_steps) * 80 # 10-90% range for training
progress = 10 + step_progress
else:
progress = 50 # Default fallback
elif run.status == 'evaluating':
progress = 90
elif run.status == 'completed':
progress = 100
elif run.status == 'failed':
progress = 0
# Get training mode from config
config = run.get_config() if hasattr(run, 'get_config') else {}
training_mode = config.get('training_mode', 'lora')
mode_label = 'classification head only' if training_mode == 'head_only' else 'LoRA adapters'
use_sentence_level = config.get('use_sentence_level_training', True)
status_messages = {
'preparing': 'Preparing training data...',
'training': f'Training model ({mode_label})...',
'evaluating': 'Evaluating model performance...',
'completed': 'Training completed successfully!',
'failed': 'Training failed'
}
response = {
'run_id': run_id,
'status': run.status,
'status_message': status_messages.get(run.status, run.status),
'progress': progress,
'details': '',
'current_epoch': run.current_epoch if hasattr(run, 'current_epoch') else None,
'total_epochs': run.total_epochs if hasattr(run, 'total_epochs') else None,
'current_step': run.current_step if hasattr(run, 'current_step') else None,
'total_steps': run.total_steps if hasattr(run, 'total_steps') else None,
'current_loss': run.current_loss if hasattr(run, 'current_loss') else None,
'progress_message': run.progress_message if hasattr(run, 'progress_message') else None
}
if run.status == 'training':
if hasattr(run, 'progress_message') and run.progress_message:
response['details'] = run.progress_message
else:
data_type = 'sentence-level' if use_sentence_level else 'submission-level'
response['details'] = f'Training on {run.num_training_examples} {data_type} examples...'
elif run.status == 'completed':
results = run.get_results()
if results:
response['results'] = results
response['details'] = f"Test accuracy: {results.get('test_accuracy', 0)*100:.1f}%"
elif run.status == 'failed':
response['error_message'] = run.error_message
return jsonify(response)
@bp.route('/api/deploy-model/<int:run_id>', methods=['POST'])
@admin_required
def deploy_model(run_id):
"""Deploy a fine-tuned model"""
try:
from app.fine_tuning import ModelManager
from app.analyzer import reload_analyzer
manager = ModelManager()
result = manager.deploy_model(run_id, db.session)
# Reload analyzer to use new model
reload_analyzer()
return jsonify({
'success': True,
**result
})
except Exception as e:
return jsonify({'success': False, 'error': str(e)}), 500
@bp.route('/api/rollback-model', methods=['POST'])
@admin_required
def rollback_model():
"""Rollback to base model"""
try:
from app.fine_tuning import ModelManager
from app.analyzer import reload_analyzer
manager = ModelManager()
result = manager.rollback_to_baseline(db.session)
# Reload analyzer to use base model
reload_analyzer()
return jsonify({
'success': True,
**result
})
except Exception as e:
return jsonify({'success': False, 'error': str(e)}), 500
@bp.route('/api/run-details/<int:run_id>', methods=['GET'])
@admin_required
def get_run_details(run_id):
"""Get detailed information about a training run"""
run = FineTuningRun.query.get_or_404(run_id)
return jsonify(run.to_dict())
@bp.route('/api/set-zero-shot-model', methods=['POST'])
@admin_required
def set_zero_shot_model():
"""Set the zero-shot model for classification"""
try:
from app.fine_tuning.model_presets import get_model_preset
from app.analyzer import reload_analyzer
data = request.get_json()
model_key = data.get('model_key')
if not model_key:
return jsonify({'success': False, 'error': 'No model key provided'}), 400
# Validate model exists and supports zero-shot
model_preset = get_model_preset(model_key)
if not model_preset.get('supports_zero_shot', False):
return jsonify({
'success': False,
'error': 'Selected model does not support zero-shot classification'
}), 400
# Save setting
Settings.set_setting('zero_shot_model', model_key)
# Reload analyzer with new model
reload_analyzer()
logger.info(f"Zero-shot model changed to: {model_preset['name']}")
return jsonify({
'success': True,
'message': f"Zero-shot model changed to {model_preset['name']}",
'model_key': model_key,
'model_name': model_preset['name']
})
except Exception as e:
logger.error(f"Error changing zero-shot model: {str(e)}")
return jsonify({'success': False, 'error': str(e)}), 500
@bp.route('/api/get-zero-shot-model', methods=['GET'])
@admin_required
def get_zero_shot_model():
"""Get the current zero-shot model"""
try:
from app.fine_tuning.model_presets import get_model_preset
model_key = Settings.get_setting('zero_shot_model', 'bart-large-mnli')
model_preset = get_model_preset(model_key)
return jsonify({
'success': True,
'model_key': model_key,
'model_name': model_preset['name'],
'model_info': {
'size': model_preset['size'],
'speed': model_preset['speed'],
'description': model_preset['description']
}
})
except Exception as e:
logger.error(f"Error getting zero-shot model: {str(e)}")
return jsonify({'success': False, 'error': str(e)}), 500
@bp.route('/api/delete-training-run/<int:run_id>', methods=['DELETE'])
@admin_required
def delete_training_run(run_id):
"""Delete a training run and its associated files"""
try:
run = FineTuningRun.query.get_or_404(run_id)
# Prevent deletion of active model
if run.is_active_model:
return jsonify({
'success': False,
'error': 'Cannot delete the active model. Please rollback or deploy another model first.'
}), 400
# Prevent deletion of currently training runs
if run.status == 'training':
return jsonify({
'success': False,
'error': 'Cannot delete a training run that is currently in progress.'
}), 400
# Delete model files if they exist
import shutil
if run.model_path and os.path.exists(run.model_path):
try:
shutil.rmtree(run.model_path)
logger.info(f"Deleted model files at {run.model_path}")
except Exception as e:
logger.error(f"Error deleting model files: {str(e)}")
# Continue with database deletion even if file deletion fails
# Unlink training examples from this run (don't delete the examples themselves)
for example in run.training_examples:
example.training_run_id = None
example.used_in_training = False
# Delete the training run from database
db.session.delete(run)
db.session.commit()
return jsonify({
'success': True,
'message': f'Training run #{run_id} deleted successfully'
})
except Exception as e:
db.session.rollback()
logger.error(f"Error deleting training run: {str(e)}")
return jsonify({'success': False, 'error': str(e)}), 500
@bp.route('/api/force-delete-training-run/<int:run_id>', methods=['DELETE'])
@admin_required
def force_delete_training_run(run_id):
"""Force delete a training run, bypassing all safety checks"""
try:
run = FineTuningRun.query.get_or_404(run_id)
# If this is the active model, deactivate it first
if run.is_active_model:
run.is_active_model = False
logger.warning(f"Force deleting active model run #{run_id}")
# Delete model files if they exist
import shutil
if run.model_path and os.path.exists(run.model_path):
try:
shutil.rmtree(run.model_path)
logger.info(f"Deleted model files at {run.model_path}")
except Exception as e:
logger.error(f"Error deleting model files: {str(e)}")
# Continue with database deletion even if file deletion fails
# Unlink training examples from this run (don't delete the examples themselves)
for example in run.training_examples:
example.training_run_id = None
example.used_in_training = False
# Delete the training run from database
db.session.delete(run)
db.session.commit()
return jsonify({
'success': True,
'message': f'Training run #{run_id} force deleted successfully'
})
except Exception as e:
db.session.rollback()
logger.error(f"Error force deleting training run: {str(e)}")
return jsonify({'success': False, 'error': str(e)}), 500
@bp.route('/api/export-model/<int:run_id>', methods=['GET'])
@admin_required
def export_model(run_id):
"""Export a trained model as a downloadable ZIP file"""
try:
import tempfile
import shutil
run = FineTuningRun.query.get_or_404(run_id)
if run.status != 'completed':
return jsonify({
'success': False,
'error': 'Can only export completed training runs'
}), 400
if not run.model_path or not os.path.exists(run.model_path):
return jsonify({
'success': False,
'error': 'Model files not found'
}), 404
# Create temporary directory for export
temp_dir = tempfile.mkdtemp()
try:
export_name = f"model_run_{run_id}"
export_path = os.path.join(temp_dir, export_name)
# Copy model files
shutil.copytree(run.model_path, export_path)
# Create model card with metadata
config = run.get_config()
results = run.get_results()
model_card = {
'run_id': run_id,
'export_date': datetime.utcnow().isoformat(),
'created_at': run.created_at.isoformat() if run.created_at else None,
'training_mode': config.get('training_mode', 'lora'),
'base_model': 'facebook/bart-large-mnli',
'model_type': 'BART fine-tuned for text classification',
'task': 'Multi-class text classification',
'categories': ['Vision', 'Problem', 'Objectives', 'Directives', 'Values', 'Actions'],
'training_config': config,
'results': results,
'improvement_over_baseline': run.improvement_over_baseline,
'num_training_examples': run.num_training_examples,
'num_validation_examples': run.num_validation_examples,
'num_test_examples': run.num_test_examples
}
with open(os.path.join(export_path, 'model_card.json'), 'w') as f:
json.dump(model_card, f, indent=2)
# Create README
readme_content = f"""# Participatory Planning Model - Run {run_id}
## Model Information
- **Export Date**: {datetime.utcnow().strftime('%Y-%m-%d %H:%M UTC')}
- **Training Mode**: {config.get('training_mode', 'lora').upper()}
- **Base Model**: facebook/bart-large-mnli
- **Task**: Multi-class text classification
## Categories
1. Vision
2. Problem
3. Objectives
4. Directives
5. Values
6. Actions
## Training Configuration
- **Learning Rate**: {config.get('learning_rate', 'N/A')}
- **Epochs**: {config.get('num_epochs', 'N/A')}
- **Batch Size**: {config.get('batch_size', 'N/A')}
- **Training Examples**: {run.num_training_examples}
- **Validation Examples**: {run.num_validation_examples}
- **Test Examples**: {run.num_test_examples}
## Performance
- **Test Accuracy**: {results.get('test_accuracy', 0)*100:.1f}%
- **Improvement over Baseline**: {run.improvement_over_baseline*100:.1f}%
## Usage
To load this model:
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained("./model_run_{run_id}")
model = AutoModelForSequenceClassification.from_pretrained("./model_run_{run_id}")
```
See model_card.json for detailed metrics.
"""
with open(os.path.join(export_path, 'README.md'), 'w') as f:
f.write(readme_content)
# Create ZIP file
zip_path = os.path.join(temp_dir, f"model_run_{run_id}")
shutil.make_archive(zip_path, 'zip', temp_dir, export_name)
zip_file = f"{zip_path}.zip"
# Read ZIP file into memory before cleaning up temp dir
with open(zip_file, 'rb') as f:
zip_data = io.BytesIO(f.read())
# Clean up temp directory
shutil.rmtree(temp_dir)
# Send file from memory
zip_data.seek(0)
return send_file(
zip_data,
mimetype='application/zip',
as_attachment=True,
download_name=f'participatory_planner_model_run_{run_id}_{datetime.now().strftime("%Y%m%d")}.zip'
)
except Exception as e:
# Clean up temp dir if error occurs
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
raise e
except Exception as e:
logger.error(f"Error exporting model: {str(e)}")
return jsonify({'success': False, 'error': str(e)}), 500
@bp.route('/api/import-model', methods=['POST'])
@admin_required
def import_model():
"""Import a previously exported model from ZIP file"""
try:
import tempfile
import zipfile
import shutil
if 'file' not in request.files:
return jsonify({'success': False, 'error': 'No file uploaded'}), 400
file = request.files['file']
if file.filename == '':
return jsonify({'success': False, 'error': 'No file selected'}), 400
if not file.filename.endswith('.zip'):
return jsonify({'success': False, 'error': 'File must be a ZIP archive'}), 400
# Create temporary directory for extraction
with tempfile.TemporaryDirectory() as temp_dir:
# Save uploaded ZIP
zip_path = os.path.join(temp_dir, 'upload.zip')
file.save(zip_path)
# Extract ZIP
extract_dir = os.path.join(temp_dir, 'extracted')
os.makedirs(extract_dir, exist_ok=True)
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(extract_dir)
# Find the model directory (should be model_run_X)
contents = os.listdir(extract_dir)
if len(contents) != 1:
return jsonify({'success': False, 'error': 'Invalid model archive structure'}), 400
model_dir = os.path.join(extract_dir, contents[0])
# Validate it's a valid model
            model_files = ['pytorch_model.bin', 'model.safetensors']  # accept either weights format
has_config = os.path.exists(os.path.join(model_dir, 'config.json'))
has_model = any(os.path.exists(os.path.join(model_dir, f)) for f in model_files)
if not has_config or not has_model:
return jsonify({
'success': False,
'error': 'Invalid model archive - missing required files (config.json and model weights)'
}), 400
# Read model card if available
model_info = {}
model_card_path = os.path.join(model_dir, 'model_card.json')
if os.path.exists(model_card_path):
with open(model_card_path, 'r') as f:
model_info = json.load(f)
# Create new training run record
training_run = FineTuningRun(
status='completed',
created_at=datetime.utcnow()
)
# Set config from model card if available
if 'training_config' in model_info:
training_run.set_config(model_info['training_config'])
else:
# Default config for imported models
training_run.set_config({
'training_mode': 'imported',
'imported': True,
'original_filename': file.filename
})
# Set metadata from model card
if 'num_training_examples' in model_info:
training_run.num_training_examples = model_info['num_training_examples']
if 'num_validation_examples' in model_info:
training_run.num_validation_examples = model_info['num_validation_examples']
if 'num_test_examples' in model_info:
training_run.num_test_examples = model_info['num_test_examples']
if 'results' in model_info:
training_run.set_results(model_info['results'])
if 'improvement_over_baseline' in model_info:
training_run.improvement_over_baseline = model_info['improvement_over_baseline']
training_run.completed_at = datetime.utcnow()
db.session.add(training_run)
db.session.commit()
# Copy model to models directory
models_dir = os.getenv('MODELS_DIR', '/data/models/finetuned')
destination_path = os.path.join(models_dir, f'run_{training_run.id}')
shutil.copytree(model_dir, destination_path)
training_run.model_path = destination_path
db.session.commit()
logger.info(f"Model imported successfully as run {training_run.id}")
return jsonify({
'success': True,
'run_id': training_run.id,
'message': f'Model imported successfully as run #{training_run.id}',
'model_info': model_info
})
except zipfile.BadZipFile:
return jsonify({'success': False, 'error': 'Invalid ZIP file'}), 400
except Exception as e:
db.session.rollback()
logger.error(f"Error importing model: {str(e)}")
return jsonify({'success': False, 'error': str(e)}), 500