Spaces:
Sleeping
Sleeping
| from flask import Blueprint, render_template, request, redirect, url_for, session, flash, jsonify, send_file | |
| from app.models.models import Token, Submission, Settings, TrainingExample, FineTuningRun, SubmissionSentence | |
| from app import db | |
| from app.analyzer import get_analyzer | |
| from app.utils.pdf_export import DashboardPDFExporter | |
| from functools import wraps | |
| from typing import Dict | |
| import json | |
| import csv | |
| import io | |
| from datetime import datetime | |
| import os | |
| import logging | |
| logger = logging.getLogger(__name__) | |
# Admin blueprint: every route in this module is served under /admin.
bp = Blueprint('admin', __name__, url_prefix='/admin')

# Stakeholder categories a token can be issued for. 'value' is the string
# stored on Token/Submission rows; 'label' and 'description' are display-only.
CONTRIBUTOR_TYPES = [
    {'value': 'government', 'label': 'Government Officer', 'description': 'Public sector representatives'},
    {'value': 'community', 'label': 'Community Member', 'description': 'Local residents and community leaders'},
    {'value': 'industry', 'label': 'Industry Representative', 'description': 'Business and industry stakeholders'},
    {'value': 'ngo', 'label': 'NGO/Non-Profit', 'description': 'Civil society organizations'},
    {'value': 'academic', 'label': 'Academic/Researcher', 'description': 'Universities and research institutions'},
    {'value': 'other', 'label': 'Other Stakeholder', 'description': 'Other interested parties'}
]

# Classification labels the analyzer may assign to submissions/sentences.
CATEGORIES = ['Vision', 'Problem', 'Objectives', 'Directives', 'Values', 'Actions']
def admin_required(f):
    """Decorator restricting a view to authenticated admin sessions.

    Redirects to the login page when the session carries no token or its
    type is not 'admin'; otherwise invokes the wrapped view unchanged.
    """
    # @wraps preserves the wrapped view's __name__/__doc__. Without it every
    # decorated view registers under the Flask endpoint name
    # 'decorated_function', causing endpoint-name collisions ('wraps' was
    # imported at the top of this module but never applied).
    @wraps(f)
    def decorated_function(*args, **kwargs):
        if 'token' not in session or session.get('type') != 'admin':
            return redirect(url_for('auth.login'))
        return f(*args, **kwargs)
    return decorated_function
def overview():
    """Render the admin overview page with headline statistics."""
    context = {
        'total_submissions': Submission.query.count(),
        # Admin tokens are excluded from the public token count.
        'total_tokens': Token.query.filter(Token.type != 'admin').count(),
        'flagged_count': Submission.query.filter_by(flagged_as_offensive=True).count(),
        'unanalyzed_count': Submission.query.filter_by(category=None).count(),
        'submission_open': Settings.get_setting('submission_open', 'true') == 'true',
        'token_generation_enabled': Settings.get_setting('token_generation_enabled', 'true') == 'true',
        # True once at least one submission has been categorized.
        'analyzed': Submission.query.filter(Submission.category != None).count() > 0,
    }
    return render_template('admin/overview.html', **context)
def registration():
    """Render the registration management page (token generation toggle,
    the ten most recent non-admin tokens, and the public registration URL)."""
    generation_on = Settings.get_setting('token_generation_enabled', 'true') == 'true'
    latest_tokens = (
        Token.query
        .filter(Token.type != 'admin')
        .order_by(Token.created_at.desc())
        .limit(10)
        .all()
    )
    public_url = request.host_url.rstrip('/') + url_for('auth.generate')
    return render_template(
        'admin/registration.html',
        token_generation_enabled=generation_on,
        recent_tokens=latest_tokens,
        registration_url=public_url,
    )
def tokens():
    """Render the token management page listing every token."""
    return render_template(
        'admin/tokens.html',
        tokens=Token.query.all(),
        contributor_types=CONTRIBUTOR_TYPES,
    )
def submissions():
    """Render the submissions list, optionally filtered via query params
    ?category=<name|all> and ?flagged=true."""
    selected_category = request.args.get('category', 'all')
    show_flagged_only = request.args.get('flagged', 'false') == 'true'

    filtered = Submission.query
    if selected_category != 'all':
        filtered = filtered.filter_by(category=selected_category)
    if show_flagged_only:
        filtered = filtered.filter_by(flagged_as_offensive=True)

    return render_template(
        'admin/submissions.html',
        submissions=filtered.order_by(Submission.timestamp.desc()).all(),
        categories=CATEGORIES,
        category_filter=selected_category,
        flagged_only=show_flagged_only,
        # Flag count is global, independent of the active filter.
        flagged_count=Submission.query.filter_by(flagged_as_offensive=True).count(),
        analyzed=Submission.query.filter(Submission.category != None).count() > 0,
    )
def dashboard():
    """Render the analysis dashboard.

    Two view modes selected by the ?mode= query parameter:
    - 'submissions' (default): aggregates whole categorized submissions.
    - 'sentences': aggregates individually categorized sentences, each
      inheriting contributor type, timestamp and location from its parent
      submission.
    Redirects to the overview with a warning when nothing is analyzed yet.
    """
    # Require at least one analyzed submission before showing the dashboard.
    analyzed = Submission.query.filter(Submission.category != None).count() > 0
    if not analyzed:
        flash('Please analyze submissions first', 'warning')
        return redirect(url_for('admin.overview'))

    view_mode = request.args.get('mode', 'submissions')

    # Contributor stats are always submission-based, regardless of mode.
    contributor_stats = db.session.query(
        Submission.contributor_type,
        db.func.count(Submission.id)
    ).group_by(Submission.contributor_type).all()

    if view_mode == 'sentences':
        # Wrapper exposing sentence data plus parent-submission metadata
        # under the attribute names the templates expect. Defined once here
        # rather than inside the row loops (the original re-created the
        # class object on every loop iteration).
        class SentenceView:
            def __init__(self, sentence, submission):
                self.id = sentence.id
                self.text = sentence.text
                self.message = sentence.text  # template compatibility
                self.category = sentence.category
                self.confidence = sentence.confidence
                self.contributor_type = submission.contributor_type
                self.timestamp = submission.timestamp
                self.latitude = submission.latitude
                self.longitude = submission.longitude
                self.submission_id = submission.id

        # All categorized sentences joined with their parent submissions.
        sentence_rows = db.session.query(SubmissionSentence, Submission).join(
            Submission
        ).filter(
            SubmissionSentence.category != None
        ).all()
        items_by_category = [SentenceView(s, sub) for s, sub in sentence_rows]

        # Per-category sentence counts.
        category_stats = db.session.query(
            SubmissionSentence.category,
            db.func.count(SubmissionSentence.id)
        ).filter(SubmissionSentence.category != None).group_by(SubmissionSentence.category).all()

        # Category x contributor-type breakdown; the contributor type comes
        # from the parent submission.
        breakdown = {}
        for cat in CATEGORIES:
            breakdown[cat] = {}
            for ctype in CONTRIBUTOR_TYPES:
                count = db.session.query(db.func.count(SubmissionSentence.id)).join(
                    Submission
                ).filter(
                    SubmissionSentence.category == cat,
                    Submission.contributor_type == ctype['value']
                ).scalar()
                breakdown[cat][ctype['value']] = count

        # Categorized sentences whose parent submission is geotagged.
        geotagged_rows = db.session.query(SubmissionSentence, Submission).join(
            Submission
        ).filter(
            Submission.latitude != None,
            Submission.longitude != None,
            SubmissionSentence.category != None
        ).all()
        geotagged_data = [SentenceView(s, sub) for s, sub in geotagged_rows]
    else:
        # SUBMISSION-LEVEL VIEW (default).
        items_by_category = Submission.query.filter(Submission.category != None).all()

        category_stats = db.session.query(
            Submission.category,
            db.func.count(Submission.id)
        ).filter(Submission.category != None).group_by(Submission.category).all()

        breakdown = {}
        for cat in CATEGORIES:
            breakdown[cat] = {}
            for ctype in CONTRIBUTOR_TYPES:
                breakdown[cat][ctype['value']] = Submission.query.filter_by(
                    category=cat,
                    contributor_type=ctype['value']
                ).count()

        # Geotagged, categorized submissions.
        geotagged_data = Submission.query.filter(
            Submission.latitude != None,
            Submission.longitude != None,
            Submission.category != None
        ).all()

    return render_template('admin/dashboard.html',
                           items=items_by_category,
                           contributor_stats=contributor_stats,
                           category_stats=category_stats,
                           geotagged_items=geotagged_data,
                           categories=CATEGORIES,
                           contributor_types=CONTRIBUTOR_TYPES,
                           breakdown=breakdown,
                           view_mode=view_mode)
def export_dashboard_pdf():
    """Export dashboard data as a PDF download based on view mode.

    Mirrors dashboard()'s data gathering (?mode=submissions|sentences),
    renders it via DashboardPDFExporter, and streams the result. On any
    error, logs, flashes a message, and redirects back to the dashboard.
    """
    try:
        view_mode = request.args.get('mode', 'submissions')

        # Contributor stats are always submission-based.
        contributor_stats = db.session.query(
            Submission.contributor_type,
            db.func.count(Submission.id)
        ).group_by(Submission.contributor_type).all()

        if view_mode == 'sentences':
            # Wrapper exposing sentence data plus parent-submission metadata
            # under the attribute names the exporter expects. Defined once
            # here rather than inside the row loops (the original re-created
            # the class object on every loop iteration).
            class SentenceView:
                def __init__(self, sentence, submission):
                    self.id = sentence.id
                    self.text = sentence.text
                    self.message = sentence.text  # exporter compatibility
                    self.category = sentence.category
                    self.confidence = sentence.confidence
                    self.contributor_type = submission.contributor_type
                    self.timestamp = submission.timestamp
                    self.latitude = submission.latitude
                    self.longitude = submission.longitude
                    self.submission_id = submission.id

            # All categorized sentences joined with their parent submissions.
            sentence_rows = db.session.query(SubmissionSentence, Submission).join(
                Submission
            ).filter(
                SubmissionSentence.category != None
            ).all()
            items_list = [SentenceView(s, sub) for s, sub in sentence_rows]

            category_stats = db.session.query(
                SubmissionSentence.category,
                db.func.count(SubmissionSentence.id)
            ).filter(SubmissionSentence.category != None).group_by(SubmissionSentence.category).all()

            # Category x contributor-type breakdown via the parent submission.
            breakdown = {}
            for cat in CATEGORIES:
                breakdown[cat] = {}
                for ctype in CONTRIBUTOR_TYPES:
                    count = db.session.query(db.func.count(SubmissionSentence.id)).join(
                        Submission
                    ).filter(
                        SubmissionSentence.category == cat,
                        Submission.contributor_type == ctype['value']
                    ).scalar()
                    breakdown[cat][ctype['value']] = count

            # Categorized sentences whose parent submission is geotagged.
            geotagged_rows = db.session.query(SubmissionSentence, Submission).join(
                Submission
            ).filter(
                Submission.latitude != None,
                Submission.longitude != None,
                SubmissionSentence.category != None
            ).all()
            geotagged_data = [SentenceView(s, sub) for s, sub in geotagged_rows]
        else:
            # SUBMISSION-LEVEL VIEW (default).
            items_list = Submission.query.filter(Submission.category != None).all()

            category_stats = db.session.query(
                Submission.category,
                db.func.count(Submission.id)
            ).filter(Submission.category != None).group_by(Submission.category).all()

            breakdown = {}
            for cat in CATEGORIES:
                breakdown[cat] = {}
                for ctype in CONTRIBUTOR_TYPES:
                    breakdown[cat][ctype['value']] = Submission.query.filter_by(
                        category=cat,
                        contributor_type=ctype['value']
                    ).count()

            geotagged_data = Submission.query.filter(
                Submission.latitude != None,
                Submission.longitude != None,
                Submission.category != None
            ).all()

        # Assemble everything the PDF exporter needs.
        pdf_data = {
            'submissions': items_list,  # sentences or submissions per view_mode
            'category_stats': category_stats,
            'contributor_stats': contributor_stats,
            'breakdown': breakdown,
            'geotagged_submissions': geotagged_data,
            'view_mode': view_mode,
            'categories': CATEGORIES,
            'contributor_types': CONTRIBUTOR_TYPES
        }

        # Render the PDF into an in-memory buffer.
        buffer = io.BytesIO()
        exporter = DashboardPDFExporter()
        exporter.generate_pdf(buffer, pdf_data)
        buffer.seek(0)

        mode_label = "sentence" if view_mode == 'sentences' else "submission"
        filename = f"dashboard_{mode_label}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pdf"
        return send_file(
            buffer,
            mimetype='application/pdf',
            as_attachment=True,
            download_name=filename
        )
    except Exception as e:
        logger.error(f"Error exporting dashboard PDF: {str(e)}")
        flash(f'Error exporting PDF: {str(e)}', 'danger')
        return redirect(url_for('admin.dashboard'))
| # API Endpoints | |
def toggle_submissions():
    """Flip the global 'submissions open' setting (API endpoint)."""
    currently_open = Settings.get_setting('submission_open', 'true') == 'true'
    Settings.set_setting('submission_open', 'false' if currently_open else 'true')
    return jsonify({'success': True, 'submission_open': not currently_open})
def toggle_token_generation():
    """Flip the global token-generation setting (API endpoint)."""
    currently_on = Settings.get_setting('token_generation_enabled', 'true') == 'true'
    Settings.set_setting('token_generation_enabled', 'false' if currently_on else 'true')
    return jsonify({'success': True, 'token_generation_enabled': not currently_on})
def create_token():
    """Create a new access token (API endpoint).

    Expects JSON {'type': <contributor type or 'admin'>, 'name': optional
    display name}. Returns the created token as JSON, or 400 for an
    invalid/missing type.
    """
    data = request.json
    contributor_type = data.get('type')
    name = data.get('name', '').strip()

    # Allow 'admin' type in addition to the public contributor types.
    valid_types = [t['value'] for t in CONTRIBUTOR_TYPES] + ['admin']
    if not contributor_type or contributor_type not in valid_types:
        return jsonify({'success': False, 'error': 'Invalid contributor type'}), 400

    # Token format: <3-letter type prefix>-<6 random chars><last 4 timestamp digits>.
    # Tokens act as login credentials, so use the cryptographically secure
    # `secrets` module instead of `random` for the random part.
    import secrets
    import string
    alphabet = string.ascii_uppercase + string.digits
    prefix = contributor_type[:3].upper()
    random_part = ''.join(secrets.choice(alphabet) for _ in range(6))
    timestamp_part = str(int(datetime.now().timestamp()))[-4:]
    token_str = f"{prefix}-{random_part}{timestamp_part}"

    # Default display name based on type when none was supplied.
    if contributor_type == 'admin':
        final_name = name if name else "Administrator"
    else:
        final_name = name if name else f"{contributor_type.capitalize()} User"

    new_token = Token(
        token=token_str,
        type=contributor_type,
        name=final_name
    )
    db.session.add(new_token)
    db.session.commit()
    return jsonify({'success': True, 'token': new_token.to_dict()})
def delete_token(token_id):
    """Delete a token by id; admin tokens are protected (API endpoint)."""
    target = Token.query.get_or_404(token_id)
    if target.type != 'admin':
        db.session.delete(target)
        db.session.commit()
        return jsonify({'success': True})
    # Any token typed 'admin' may never be deleted through this endpoint.
    return jsonify({'success': False, 'error': 'Cannot delete admin token'}), 400
def update_category(submission_id):
    """Manually set or clear a submission's category (API endpoint).

    Expects JSON {'category': <name, or ''/'null' to clear>,
    'confidence': optional}. When a category is assigned, a TrainingExample
    is created (or updated) so the admin correction can feed fine-tuning.
    """
    try:
        submission = Submission.query.get_or_404(submission_id)
        data = request.json
        category = data.get('category')
        confidence = data.get('confidence')  # optional prediction confidence from frontend

        # Remember the pre-change category for the training example.
        original_category = submission.category

        # Treat empty string / 'null' as clearing the category.
        if category == '' or category == 'null':
            category = None

        # Validate the category name when one is being assigned.
        if category and category not in CATEGORIES:
            return jsonify({'success': False, 'error': f'Invalid category: {category}'}), 400

        # Record the correction/confirmation as training data, but only when
        # a category is actually assigned (clearing is not tracked).
        if category is not None:
            existing_example = TrainingExample.query.filter_by(submission_id=submission_id).first()
            if existing_example:
                # One training example per submission: update in place.
                existing_example.original_category = original_category
                existing_example.corrected_category = category
                existing_example.correction_timestamp = datetime.utcnow()
                existing_example.confidence_score = confidence
            else:
                training_example = TrainingExample(
                    submission_id=submission_id,
                    message=submission.message,
                    original_category=original_category,
                    corrected_category=category,
                    contributor_type=submission.contributor_type,
                    confidence_score=confidence
                )
                db.session.add(training_example)

        submission.category = category
        db.session.commit()
        return jsonify({'success': True, 'category': category})
    except Exception as e:
        db.session.rollback()
        # Use the module logger for consistency with the rest of this module
        # (the original printed to stdout here).
        logger.error(f"Error updating category: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500
def update_sentence_category(sentence_id):
    """Update category for a specific sentence (API endpoint).

    Expects JSON {'category': <name>}. Records a TrainingExample for the
    sentence (when a category is assigned) and recalculates the parent
    submission's primary category from its sentences.
    """
    try:
        sentence = SubmissionSentence.query.get_or_404(sentence_id)
        data = request.json
        new_category = data.get('category')
        # Remember the pre-change value for the training example.
        original_category = sentence.category
        # Validate category; falsy values (None/'') pass through and clear it.
        if new_category and new_category not in CATEGORIES:
            return jsonify({'success': False, 'error': f'Invalid category: {new_category}'}), 400
        # Update sentence
        sentence.category = new_category
        # Create/update a training example only when a category is assigned.
        if new_category:
            existing = TrainingExample.query.filter_by(sentence_id=sentence_id).first()
            if existing:
                # One training example per sentence: update in place.
                existing.original_category = original_category
                existing.corrected_category = new_category
                existing.correction_timestamp = datetime.utcnow()
            else:
                training_example = TrainingExample(
                    sentence_id=sentence_id,
                    submission_id=sentence.submission_id,
                    message=sentence.text,  # just this sentence's text, not the whole submission
                    original_category=original_category,
                    corrected_category=new_category,
                    contributor_type=sentence.submission.contributor_type
                )
                db.session.add(training_example)
        # Keep the parent submission's primary category in sync with its
        # sentences (recalculated by the model).
        submission = sentence.submission
        submission.category = submission.get_primary_category()
        db.session.commit()
        return jsonify({'success': True, 'category': new_category})
    except Exception as e:
        db.session.rollback()
        logger.error(f"Error updating sentence category: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500
def toggle_flag(submission_id):
    """Invert the offensive flag on a submission (API endpoint)."""
    target = Submission.query.get_or_404(submission_id)
    target.flagged_as_offensive = not target.flagged_as_offensive
    db.session.commit()
    return jsonify({'success': True, 'flagged': target.flagged_as_offensive})
def delete_submission(submission_id):
    """Permanently remove a submission (API endpoint)."""
    target = Submission.query.get_or_404(submission_id)
    db.session.delete(target)
    db.session.commit()
    return jsonify({'success': True})
def analyze_submissions():
    """Run the analyzer over pending submissions (API endpoint).

    Expects JSON {'analyze_all': bool, 'use_sentences': bool}. With
    use_sentences (default True) each message is split into categorized
    sentences and the submission's primary category is derived from them;
    otherwise one category is assigned per submission. Commits in batches
    and retries 'database is locked' errors with exponential backoff.
    Returns counts of analyzed submissions and errors.
    """
    import time
    from sqlalchemy.exc import OperationalError
    data = request.json
    analyze_all = data.get('analyze_all', False)
    use_sentences = data.get('use_sentences', True)  # NEW: sentence-level flag (default: True)
    # Get submissions to analyze
    if analyze_all:
        to_analyze = Submission.query.all()
    else:
        # For sentence-level, look for submissions without sentence analysis
        if use_sentences:
            to_analyze = Submission.query.filter_by(sentence_analysis_done=False).all()
        else:
            to_analyze = Submission.query.filter_by(category=None).all()
    if not to_analyze:
        return jsonify({'success': False, 'error': 'No submissions to analyze'}), 400
    # Get the analyzer instance
    analyzer = get_analyzer()
    success_count = 0
    error_count = 0
    batch_size = 10  # Commit every 10 submissions to reduce lock time
    for idx, submission in enumerate(to_analyze):
        max_retries = 3
        retry_delay = 1  # seconds (base delay, doubled per retry attempt)
        for attempt in range(max_retries):
            try:
                if use_sentences:
                    # NEW: Sentence-level analysis
                    sentence_results = analyzer.analyze_with_sentences(submission.message)
                    # Optimized DELETE: Use synchronize_session=False for better performance.
                    # Clears any sentences left over from a previous analysis run.
                    SubmissionSentence.query.filter_by(submission_id=submission.id).delete(synchronize_session=False)
                    # Create new sentence records
                    for sent_idx, result in enumerate(sentence_results):
                        sentence = SubmissionSentence(
                            submission_id=submission.id,
                            sentence_index=sent_idx,
                            text=result['text'],
                            category=result['category'],
                            confidence=result.get('confidence')
                        )
                        db.session.add(sentence)
                    submission.sentence_analysis_done = True
                    # Set primary category for backward compatibility
                    submission.category = submission.get_primary_category()
                    logger.info(f"Analyzed submission {submission.id} into {len(sentence_results)} sentences")
                else:
                    # OLD: Submission-level analysis (backward compatible)
                    category = analyzer.analyze(submission.message)
                    submission.category = category
                success_count += 1
                # Commit in batches to reduce lock duration
                if (idx + 1) % batch_size == 0:
                    db.session.commit()
                    logger.info(f"Committed batch of {batch_size} submissions")
                break  # Success, exit retry loop
            except OperationalError as e:
                # Database locked error - retry with exponential backoff
                if 'database is locked' in str(e) and attempt < max_retries - 1:
                    db.session.rollback()
                    wait_time = retry_delay * (2 ** attempt)  # Exponential backoff
                    logger.warning(f"Database locked for submission {submission.id}, retrying in {wait_time}s (attempt {attempt + 1}/{max_retries})")
                    time.sleep(wait_time)
                    continue
                else:
                    # Max retries reached or different error
                    # NOTE(review): this rollback also discards earlier
                    # uncommitted items in the current batch, so success_count
                    # may overcount what was persisted — confirm intended.
                    db.session.rollback()
                    logger.error(f"Error analyzing submission {submission.id}: {e}")
                    error_count += 1
                    break
            except Exception as e:
                db.session.rollback()
                logger.error(f"Error analyzing submission {submission.id}: {e}")
                error_count += 1
                break
    # Final commit for remaining items
    try:
        db.session.commit()
        logger.info(f"Final commit completed")
    except Exception as e:
        db.session.rollback()
        logger.error(f"Error in final commit: {e}")
    return jsonify({
        'success': True,
        'analyzed': success_count,
        'errors': error_count,
        'sentence_level': use_sentences
    })
def export_json():
    """Export tokens, submissions, training examples and settings as a
    dated JSON file download."""
    payload = {
        'tokens': [t.to_dict() for t in Token.query.all()],
        'submissions': [s.to_dict() for s in Submission.query.all()],
        'trainingExamples': [ex.to_dict() for ex in TrainingExample.query.all()],
        'submissionOpen': Settings.get_setting('submission_open', 'true') == 'true',
        'tokenGenerationEnabled': Settings.get_setting('token_generation_enabled', 'true') == 'true',
        'exportDate': datetime.utcnow().isoformat()
    }
    # Serialize straight into an in-memory byte stream for send_file.
    buffer = io.BytesIO(json.dumps(payload, indent=2).encode('utf-8'))
    return send_file(
        buffer,
        mimetype='application/json',
        as_attachment=True,
        download_name=f'participatory-planning-{datetime.now().strftime("%Y-%m-%d")}.json'
    )
def export_csv():
    """Export all submissions as a dated CSV file download."""
    text_buffer = io.StringIO()
    writer = csv.writer(text_buffer)
    writer.writerow(['Timestamp', 'Contributor Type', 'Category', 'Message',
                     'Latitude', 'Longitude', 'Flagged'])
    for record in Submission.query.all():
        writer.writerow([
            record.timestamp.isoformat() if record.timestamp else '',
            record.contributor_type,
            record.category or 'Not analyzed',
            record.message,
            record.latitude or '',
            record.longitude or '',
            'Yes' if record.flagged_as_offensive else 'No',
        ])
    # send_file needs a binary stream, so re-encode the CSV text.
    byte_buffer = io.BytesIO(text_buffer.getvalue().encode('utf-8'))
    return send_file(
        byte_buffer,
        mimetype='text/csv',
        as_attachment=True,
        download_name=f'contributions-{datetime.now().strftime("%Y-%m-%d")}.csv'
    )
def import_data():
    """Import a previously exported JSON dump (API endpoint).

    Replaces all submissions and non-admin tokens with the uploaded data,
    re-imports training examples (matched back to submissions by message
    text), restores the two global settings, and commits atomically —
    rolling back on any failure.
    """
    if 'file' not in request.files:
        return jsonify({'success': False, 'error': 'No file uploaded'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'success': False, 'error': 'No file selected'}), 400
    try:
        data = json.load(file)
        # Clear existing data (except admin token)
        Submission.query.delete()
        Token.query.filter(Token.type != 'admin').delete()
        # Import tokens
        for token_data in data.get('tokens', []):
            if token_data.get('type') != 'admin':  # Skip admin token as it already exists
                token = Token(
                    token=token_data['token'],
                    type=token_data['type'],
                    name=token_data['name']
                )
                db.session.add(token)
        # Import submissions
        for sub_data in data.get('submissions', []):
            location = sub_data.get('location')
            submission = Submission(
                message=sub_data['message'],
                contributor_type=sub_data['contributorType'],
                latitude=location['lat'] if location else None,
                longitude=location['lng'] if location else None,
                timestamp=datetime.fromisoformat(sub_data['timestamp']) if sub_data.get('timestamp') else datetime.utcnow(),
                category=sub_data.get('category'),
                flagged_as_offensive=sub_data.get('flaggedAsOffensive', False)
            )
            db.session.add(submission)
        # Import training examples if present.
        # NOTE(review): matching by message relies on the session autoflushing
        # the submissions added above before this query runs (and on messages
        # being unique) — confirm autoflush is enabled on this session.
        training_examples_imported = 0
        for ex_data in data.get('trainingExamples', []):
            # Find corresponding submission by message (or create placeholder)
            submission = Submission.query.filter_by(message=ex_data['message']).first()
            if submission:
                training_example = TrainingExample(
                    submission_id=submission.id,
                    message=ex_data['message'],
                    original_category=ex_data.get('original_category'),
                    corrected_category=ex_data['corrected_category'],
                    contributor_type=ex_data['contributor_type'],
                    correction_timestamp=datetime.fromisoformat(ex_data['correction_timestamp']) if ex_data.get('correction_timestamp') else datetime.utcnow(),
                    confidence_score=ex_data.get('confidence_score'),
                    used_in_training=ex_data.get('used_in_training', False)
                )
                db.session.add(training_example)
                training_examples_imported += 1
        # Import settings
        Settings.set_setting('submission_open', 'true' if data.get('submissionOpen', True) else 'false')
        Settings.set_setting('token_generation_enabled', 'true' if data.get('tokenGenerationEnabled', True) else 'false')
        db.session.commit()
        return jsonify({
            'success': True,
            'training_examples_imported': training_examples_imported
        })
    except Exception as e:
        db.session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500
def clear_all_data():
    """Delete every submission and all non-admin tokens, then reset the
    global toggles to their defaults (API endpoint)."""
    try:
        Submission.query.delete()
        # Admin tokens survive the wipe.
        Token.query.filter(Token.type != 'admin').delete()
        # Restore both feature toggles to their default ('true') state.
        for setting_key in ('submission_open', 'token_generation_enabled'):
            Settings.set_setting(setting_key, 'true')
        db.session.commit()
        return jsonify({'success': True, 'message': 'All data cleared successfully'})
    except Exception as e:
        db.session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500
| # ============================================================================ | |
| # FINE-TUNING & TRAINING DATA ENDPOINTS | |
| # ============================================================================ | |
def training_dashboard():
    """Display the fine-tuning training dashboard."""
    total_examples = TrainingExample.query.count()
    # Corrections are examples where the admin changed the predicted label.
    corrections_count = TrainingExample.query.filter(
        TrainingExample.original_category != TrainingExample.corrected_category
    ).count()
    confirmations_count = total_examples - corrections_count

    # Per-category counts of corrected labels, zero-filled for all categories.
    category_stats = dict.fromkeys(CATEGORIES, 0)
    distribution = db.session.query(
        TrainingExample.corrected_category,
        db.func.count(TrainingExample.id)
    ).group_by(TrainingExample.corrected_category).all()
    for label, count in distribution:
        if label in category_stats:
            category_stats[label] = count

    training_runs = FineTuningRun.query.order_by(FineTuningRun.created_at.desc()).all()
    active_model = FineTuningRun.query.filter_by(is_active_model=True).first()

    # Fine-tuning settings.
    min_training_examples = int(Settings.get_setting('min_training_examples', '20'))
    fine_tuning_enabled = Settings.get_setting('fine_tuning_enabled', 'true') == 'true'

    return render_template('admin/training.html',
                           total_examples=total_examples,
                           corrections_count=corrections_count,
                           confirmations_count=confirmations_count,
                           category_stats=category_stats,
                           categories=CATEGORIES,
                           training_runs=training_runs,
                           active_model=active_model,
                           min_training_examples=min_training_examples,
                           fine_tuning_enabled=fine_tuning_enabled,
                           ready_to_train=total_examples >= min_training_examples)
def get_training_stats():
    """Get training data statistics (API endpoint)."""
    total_examples = TrainingExample.query.count()
    corrections_count = TrainingExample.query.filter(
        TrainingExample.original_category != TrainingExample.corrected_category
    ).count()

    # Per-category counts of corrected labels, zero-filled for all categories.
    category_stats = dict.fromkeys(CATEGORIES, 0)
    distribution = db.session.query(
        TrainingExample.corrected_category,
        db.func.count(TrainingExample.id)
    ).group_by(TrainingExample.corrected_category).all()
    for label, count in distribution:
        if label in category_stats:
            category_stats[label] = count

    # Data-quality check: distinct messages appearing more than once.
    duplicates = db.session.query(
        TrainingExample.message,
        db.func.count(TrainingExample.id)
    ).group_by(TrainingExample.message).having(db.func.count(TrainingExample.id) > 1).count()

    min_examples = int(Settings.get_setting('min_training_examples', '20'))
    min_per_category = min(category_stats.values()) if category_stats.values() else 0

    return jsonify({
        'total_examples': total_examples,
        'corrections_count': corrections_count,
        'confirmations_count': total_examples - corrections_count,
        'category_stats': category_stats,
        'duplicates_count': duplicates,
        'min_examples_threshold': min_examples,
        'min_examples_per_category': min_per_category,
        'ready_to_train': total_examples >= min_examples and min_per_category >= 2
    })
def get_training_examples():
    """Return a paginated JSON listing of training examples.

    Query params: page, per_page, category ('all' disables the filter),
    corrections_only ('true' restricts to relabeled examples).
    """
    page = request.args.get('page', 1, type=int)
    per_page = request.args.get('per_page', 50, type=int)
    wanted_category = request.args.get('category', 'all')
    only_corrections = request.args.get('corrections_only', 'false') == 'true'
    query = TrainingExample.query
    if wanted_category != 'all':
        query = query.filter_by(corrected_category=wanted_category)
    if only_corrections:
        # Keep only rows where the admin overrode the model's label
        query = query.filter(TrainingExample.original_category != TrainingExample.corrected_category)
    pagination = query.order_by(
        TrainingExample.correction_timestamp.desc()
    ).paginate(page=page, per_page=per_page, error_out=False)
    return jsonify({
        'examples': [item.to_dict() for item in pagination.items],
        'total': pagination.total,
        'pages': pagination.pages,
        'current_page': page
    })
def delete_training_example(example_id):
    """Delete one training example, unless a training run already consumed it."""
    try:
        example = TrainingExample.query.get_or_404(example_id)
        if example.used_in_training:
            # Examples tied to a past run are kept for reproducibility
            return jsonify({
                'success': False,
                'error': 'Cannot delete example already used in training run'
            }), 400
        db.session.delete(example)
        db.session.commit()
        return jsonify({'success': True})
    except Exception as e:
        db.session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500
def export_training_examples():
    """Send all training examples to the client as a downloadable JSON file.

    Honors the 'sentence_level_only' query parameter to restrict the export
    to examples linked to a SubmissionSentence.
    """
    try:
        only_sentences = request.args.get('sentence_level_only', 'false') == 'true'
        query = TrainingExample.query
        if only_sentences:
            query = query.filter(TrainingExample.sentence_id != None)
        rows = query.all()

        def _serialize(ex):
            # One flat record per example; timestamps rendered as ISO-8601
            return {
                'message': ex.message,
                'original_category': ex.original_category,
                'corrected_category': ex.corrected_category,
                'contributor_type': ex.contributor_type,
                'correction_timestamp': ex.correction_timestamp.isoformat() if ex.correction_timestamp else None,
                'confidence_score': ex.confidence_score,
                'is_sentence_level': ex.sentence_id is not None
            }

        payload = {
            'exported_at': datetime.utcnow().isoformat(),
            'total_examples': len(rows),
            'sentence_level_only': only_sentences,
            'examples': [_serialize(ex) for ex in rows]
        }
        # Mark the response as an attachment so browsers download it
        response = jsonify(payload)
        stamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        response.headers['Content-Disposition'] = f'attachment; filename=training_examples_{stamp}.json'
        response.headers['Content-Type'] = 'application/json'
        return response
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
def import_training_examples():
    """Import training examples from a JSON request body.

    Expects ``{"examples": [...]}`` where each item has at least 'message'
    and 'corrected_category'. Items whose (message, corrected_category)
    pair already exists are skipped; the rest are inserted unused.

    Returns JSON with imported/skipped counts, 400 on malformed input,
    500 on unexpected failures.
    """
    try:
        data = request.get_json()
        if not data or 'examples' not in data:
            return jsonify({
                'success': False,
                'error': 'Invalid import data. Expected JSON with "examples" array.'
            }), 400
        examples_data = data['examples']
        imported_count = 0
        skipped_count = 0
        for ex_data in examples_data:
            # Check if example already exists (by message and category)
            existing = TrainingExample.query.filter_by(
                message=ex_data['message'],
                corrected_category=ex_data['corrected_category']
            ).first()
            if existing:
                skipped_count += 1
                continue
            # Create new training example; imported rows start as unused
            training_example = TrainingExample(
                message=ex_data['message'],
                original_category=ex_data.get('original_category'),
                corrected_category=ex_data['corrected_category'],
                contributor_type=ex_data.get('contributor_type', 'unknown'),
                correction_timestamp=datetime.fromisoformat(ex_data['correction_timestamp']) if ex_data.get('correction_timestamp') else datetime.utcnow(),
                confidence_score=ex_data.get('confidence_score'),
                used_in_training=False
            )
            db.session.add(training_example)
            imported_count += 1
        db.session.commit()
        return jsonify({
            'success': True,
            'imported': imported_count,
            'skipped': skipped_count,
            'total_in_file': len(examples_data)
        })
    except KeyError as e:
        # A record missing a required field is a client error (400), not a
        # server error — matches import_training_dataset's handling
        db.session.rollback()
        return jsonify({'success': False, 'error': f'Missing required field: {str(e)}'}), 400
    except Exception as e:
        db.session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500
def clear_training_examples():
    """Bulk-delete training examples, optionally narrowed by JSON body flags.

    Body flags: 'unused_only' keeps examples already consumed by a run;
    'sentence_level_only' restricts deletion to sentence-linked examples.
    """
    try:
        options = request.get_json() or {}
        unused_only = options.get('unused_only', False)
        sentences_only = options.get('sentence_level_only', False)
        query = TrainingExample.query
        if unused_only:
            query = query.filter_by(used_in_training=False)
        if sentences_only:
            query = query.filter(TrainingExample.sentence_id != None)
        deleted = query.count()  # count before the rows disappear
        query.delete()
        db.session.commit()
        return jsonify({
            'success': True,
            'deleted': deleted,
            'unused_only': unused_only,
            'sentence_level_only': sentences_only
        })
    except Exception as e:
        db.session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500
def import_training_dataset():
    """Import standalone training dataset (just training examples, not full session).

    Expects a multipart upload under the 'file' field containing JSON in one
    of two layouts: a bare array of example objects, or an object with a
    'trainingExamples' key holding that array. Each example needs at least
    'message' and 'corrected_category'; a missing key yields a 400.

    Existing TrainingExample rows (matched by exact message text) are updated
    in place and NOT counted in 'imported_count'; new ones are created,
    backed by a placeholder Submission when no submission with the same
    message exists.
    """
    if 'file' not in request.files:
        return jsonify({'success': False, 'error': 'No file uploaded'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'success': False, 'error': 'No file selected'}), 400
    try:
        data = json.load(file)
        # Support both formats: array of examples or object with 'trainingExamples' key
        training_data = data if isinstance(data, list) else data.get('trainingExamples', [])
        imported_count = 0
        for ex_data in training_data:
            # Check if training example already exists (by message)
            existing = TrainingExample.query.filter_by(message=ex_data['message']).first()
            if existing:
                # Update existing example in place and refresh its timestamp
                existing.original_category = ex_data.get('original_category')
                existing.corrected_category = ex_data['corrected_category']
                existing.contributor_type = ex_data.get('contributor_type', 'other')
                existing.correction_timestamp = datetime.utcnow()
                existing.confidence_score = ex_data.get('confidence_score')
            else:
                # Create placeholder submission if needed — TrainingExample
                # rows are linked to a Submission via submission_id
                submission = Submission.query.filter_by(message=ex_data['message']).first()
                if not submission:
                    # Create placeholder submission for this training example
                    submission = Submission(
                        message=ex_data['message'],
                        contributor_type=ex_data.get('contributor_type', 'other'),
                        category=ex_data.get('corrected_category'),
                        timestamp=datetime.utcnow()
                    )
                    db.session.add(submission)
                    db.session.flush()  # Get submission ID
                # Create new training example
                training_example = TrainingExample(
                    submission_id=submission.id,
                    message=ex_data['message'],
                    original_category=ex_data.get('original_category'),
                    corrected_category=ex_data['corrected_category'],
                    contributor_type=ex_data.get('contributor_type', 'other'),
                    confidence_score=ex_data.get('confidence_score')
                )
                db.session.add(training_example)
                imported_count += 1
        db.session.commit()
        return jsonify({
            'success': True,
            'imported_count': imported_count
        })
    except KeyError as e:
        # Malformed example records are a client error, not a server error
        db.session.rollback()
        return jsonify({'success': False, 'error': f'Missing required field: {str(e)}'}), 400
    except Exception as e:
        db.session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500
| # ============================================================================ | |
| # FINE-TUNING TRAINING ORCHESTRATION ENDPOINTS | |
| # ============================================================================ | |
def start_fine_tuning():
    """Start a fine-tuning training run.

    Validates the minimum-examples threshold, records a FineTuningRun row in
    'preparing' state, and launches _run_training_job on a daemon thread so
    the request returns immediately with the run id.
    """
    try:
        # silent=True + `or {}` keeps a missing/non-JSON body from becoming
        # None (request.json could be None, and every downstream access uses
        # config.get, so an empty dict is the safe, equivalent default)
        config = request.get_json(silent=True) or {}
        # Validate minimum training examples before creating any state
        min_examples = int(Settings.get_setting('min_training_examples', '20'))
        total_examples = TrainingExample.query.count()
        if total_examples < min_examples:
            return jsonify({
                'success': False,
                'error': f'Need at least {min_examples} training examples (have {total_examples})'
            }), 400
        # Create new training run record
        training_run = FineTuningRun(
            status='preparing'
        )
        training_run.set_config(config)
        db.session.add(training_run)
        db.session.commit()
        run_id = training_run.id
        # Start training in background thread; daemon so it never blocks
        # interpreter shutdown
        import threading
        thread = threading.Thread(
            target=_run_training_job,
            args=(run_id, config)
        )
        thread.daemon = True
        thread.start()
        return jsonify({
            'success': True,
            'run_id': run_id,
            'message': 'Training started'
        })
    except Exception as e:
        db.session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500
def _run_training_job(run_id: int, config: Dict):
    """Background job for training (runs in separate thread).

    Args:
        run_id: Primary key of the FineTuningRun row to drive and update.
        config: Training configuration (splits, mode, hyperparameters).

    Progress is reported by mutating the run row's status field
    (preparing -> training -> evaluating -> completed/failed). Failures are
    recorded on the row rather than raised, since no caller joins this thread.
    """
    from app import create_app
    from app.fine_tuning import BARTFineTuner
    # Create new app context for this thread — the request context that
    # spawned us is not usable from a worker thread
    app = create_app()
    with app.app_context():
        try:
            # Get training run
            run = FineTuningRun.query.get(run_id)
            if not run:
                # Use the module logger instead of print for consistency
                logger.error(f"Training run {run_id} not found")
                return
            # Update status
            run.status = 'preparing'
            db.session.commit()
            # Get training examples (prefer sentence-level if available)
            use_sentence_level = config.get('use_sentence_level_training', True)
            if use_sentence_level:
                # Use only sentence-level training examples
                examples = TrainingExample.query.filter(TrainingExample.sentence_id != None).all()
                # Fallback to submission-level if not enough sentence-level examples
                if len(examples) < int(Settings.get_setting('min_training_examples', '20')):
                    logger.warning(f"Only {len(examples)} sentence-level examples found, including submission-level examples")
                    examples = TrainingExample.query.all()
            else:
                # Use all training examples (old behavior)
                examples = TrainingExample.query.all()
            training_data = [ex.to_dict() for ex in examples]
            logger.info(f"Using {len(training_data)} training examples ({len([e for e in examples if e.sentence_id])} sentence-level)")
            # Record split sizes on the run before training starts
            total = len(training_data)
            run.num_training_examples = int(total * config.get('train_split', 0.7))
            run.num_validation_examples = int(total * config.get('val_split', 0.15))
            run.num_test_examples = total - run.num_training_examples - run.num_validation_examples
            db.session.commit()
            # Initialize trainer
            trainer = BARTFineTuner()
            # Prepare datasets
            train_dataset, val_dataset, test_dataset = trainer.prepare_dataset(
                training_data,
                train_split=config.get('train_split', 0.7),
                val_split=config.get('val_split', 0.15),
                test_split=config.get('test_split', 0.15)
            )
            # Setup model based on training mode
            training_mode = config.get('training_mode', 'head_only')
            if training_mode == 'head_only':
                # Head-only training (recommended for small datasets)
                trainer.setup_head_only_model()
            else:
                # LoRA training
                lora_config = {
                    'r': config.get('lora_rank', 16),
                    'lora_alpha': config.get('lora_alpha', 32),
                    'lora_dropout': config.get('lora_dropout', 0.1)
                }
                trainer.setup_lora_model(lora_config)
            # Update status to training
            run.status = 'training'
            db.session.commit()
            # Train
            models_dir = os.getenv('MODELS_DIR', '/data/models/finetuned')
            output_dir = os.path.join(models_dir, f'run_{run_id}')
            training_config = {
                'learning_rate': config.get('learning_rate', 3e-4),
                'num_epochs': config.get('num_epochs', 3),
                'batch_size': config.get('batch_size', 8)
            }
            train_metrics = trainer.train(
                train_dataset,
                val_dataset,
                output_dir,
                training_config,
                run_id=run_id
            )
            # Update status to evaluating
            run.status = 'evaluating'
            run.model_path = output_dir
            db.session.commit()
            # Evaluate on test set
            test_metrics = trainer.evaluate(test_dataset, output_dir)
            # Combine metrics
            results = {
                **train_metrics,
                **test_metrics
            }
            run.set_results(results)
            # Calculate improvement over baseline (simplified - just use test accuracy)
            baseline_accuracy = 0.60  # Placeholder - could run actual baseline comparison
            run.improvement_over_baseline = results['test_accuracy'] - baseline_accuracy
            # Mark training examples as used
            for example in examples:
                example.used_in_training = True
                example.training_run_id = run_id
            # Complete
            run.status = 'completed'
            run.completed_at = datetime.utcnow()
            db.session.commit()
            logger.info(f"Training run {run_id} completed successfully")
        except Exception as e:
            # logger.exception records the traceback, which print discarded
            logger.exception(f"Training run {run_id} failed: {str(e)}")
            # The failure may have left the session in an aborted state;
            # roll back before issuing the queries below, or they would fail too
            db.session.rollback()
            run = FineTuningRun.query.get(run_id)
            if run:
                run.status = 'failed'
                run.error_message = str(e)
                db.session.commit()
def get_training_status(run_id):
    """Return JSON progress and status information for one fine-tuning run."""
    run = FineTuningRun.query.get_or_404(run_id)
    # Coarse progress per lifecycle state; 'training' is refined below from
    # the step counters when they are available
    progress = {'preparing': 10, 'evaluating': 90, 'completed': 100, 'failed': 0}.get(run.status, 0)
    if run.status == 'training':
        if run.total_steps and run.total_steps > 0 and run.current_step:
            # Training occupies the 10-90% band of the progress bar
            progress = 10 + (run.current_step / run.total_steps) * 80
        else:
            progress = 50  # Default fallback
    # Training mode comes from the stored run config (older rows may lack it)
    config = getattr(run, 'get_config', lambda: {})()
    training_mode = config.get('training_mode', 'lora')
    mode_label = 'classification head only' if training_mode == 'head_only' else 'LoRA adapters'
    use_sentence_level = config.get('use_sentence_level_training', True)
    status_messages = {
        'preparing': 'Preparing training data...',
        'training': f'Training model ({mode_label})...',
        'evaluating': 'Evaluating model performance...',
        'completed': 'Training completed successfully!',
        'failed': 'Training failed'
    }
    response = {
        'run_id': run_id,
        'status': run.status,
        'status_message': status_messages.get(run.status, run.status),
        'progress': progress,
        'details': '',
        # getattr defaults keep older schema versions (missing columns) working
        'current_epoch': getattr(run, 'current_epoch', None),
        'total_epochs': getattr(run, 'total_epochs', None),
        'current_step': getattr(run, 'current_step', None),
        'total_steps': getattr(run, 'total_steps', None),
        'current_loss': getattr(run, 'current_loss', None),
        'progress_message': getattr(run, 'progress_message', None)
    }
    if run.status == 'training':
        message = getattr(run, 'progress_message', None)
        if message:
            response['details'] = message
        else:
            data_type = 'sentence-level' if use_sentence_level else 'submission-level'
            response['details'] = f'Training on {run.num_training_examples} {data_type} examples...'
    elif run.status == 'completed':
        results = run.get_results()
        if results:
            response['results'] = results
            response['details'] = f"Test accuracy: {results.get('test_accuracy', 0)*100:.1f}%"
    elif run.status == 'failed':
        response['error_message'] = run.error_message
    return jsonify(response)
def deploy_model(run_id):
    """Activate a fine-tuned model and hot-swap it into the analyzer."""
    try:
        from app.fine_tuning import ModelManager
        from app.analyzer import reload_analyzer
        result = ModelManager().deploy_model(run_id, db.session)
        # The analyzer caches its model; force a reload to pick up the new one
        reload_analyzer()
        return jsonify({'success': True, **result})
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
def rollback_model():
    """Revert classification to the baseline (non-fine-tuned) model."""
    try:
        from app.fine_tuning import ModelManager
        from app.analyzer import reload_analyzer
        result = ModelManager().rollback_to_baseline(db.session)
        # Reload so the analyzer immediately serves the base model again
        reload_analyzer()
        return jsonify({'success': True, **result})
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
def get_run_details(run_id):
    """Get detailed information about a training run.

    Returns the run's full serialized form (FineTuningRun.to_dict()) as JSON;
    responds 404 when no run with the given id exists.
    """
    run = FineTuningRun.query.get_or_404(run_id)
    return jsonify(run.to_dict())
def set_zero_shot_model():
    """Set the zero-shot model for classification.

    Expects JSON body {"model_key": "..."}. Validates the preset supports
    zero-shot classification, persists the setting, and reloads the analyzer.
    Returns 400 for missing/invalid input, 500 on unexpected failure.
    """
    try:
        from app.fine_tuning.model_presets import get_model_preset
        from app.analyzer import reload_analyzer
        # silent=True + `or {}` turns an absent/non-JSON body into an empty
        # dict, so the check below returns the intended 400 instead of the
        # AttributeError-driven 500 that `get_json()` returning None caused
        data = request.get_json(silent=True) or {}
        model_key = data.get('model_key')
        if not model_key:
            return jsonify({'success': False, 'error': 'No model key provided'}), 400
        # Validate model exists and supports zero-shot
        model_preset = get_model_preset(model_key)
        if not model_preset.get('supports_zero_shot', False):
            return jsonify({
                'success': False,
                'error': 'Selected model does not support zero-shot classification'
            }), 400
        # Save setting
        Settings.set_setting('zero_shot_model', model_key)
        # Reload analyzer with new model
        reload_analyzer()
        logger.info(f"Zero-shot model changed to: {model_preset['name']}")
        return jsonify({
            'success': True,
            'message': f"Zero-shot model changed to {model_preset['name']}",
            'model_key': model_key,
            'model_name': model_preset['name']
        })
    except Exception as e:
        logger.error(f"Error changing zero-shot model: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500
def get_zero_shot_model():
    """Report which zero-shot model is currently configured, with preset info."""
    try:
        from app.fine_tuning.model_presets import get_model_preset
        # Fall back to the default preset key when nothing has been saved yet
        key = Settings.get_setting('zero_shot_model', 'bart-large-mnli')
        preset = get_model_preset(key)
        return jsonify({
            'success': True,
            'model_key': key,
            'model_name': preset['name'],
            'model_info': {
                'size': preset['size'],
                'speed': preset['speed'],
                'description': preset['description']
            }
        })
    except Exception as e:
        logger.error(f"Error getting zero-shot model: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500
def delete_training_run(run_id):
    """Delete a finished training run together with its on-disk model files.

    Refuses to delete the active model or a run still in progress; training
    examples are detached from the run, never deleted.
    """
    import shutil
    try:
        run = FineTuningRun.query.get_or_404(run_id)
        # Guard rail: never remove the model that is serving predictions
        if run.is_active_model:
            return jsonify({
                'success': False,
                'error': 'Cannot delete the active model. Please rollback or deploy another model first.'
            }), 400
        # Guard rail: never remove a run that is still executing
        if run.status == 'training':
            return jsonify({
                'success': False,
                'error': 'Cannot delete a training run that is currently in progress.'
            }), 400
        if run.model_path and os.path.exists(run.model_path):
            try:
                shutil.rmtree(run.model_path)
                logger.info(f"Deleted model files at {run.model_path}")
            except Exception as e:
                # A leftover directory is tolerable; the DB record still goes away
                logger.error(f"Error deleting model files: {str(e)}")
        # Detach (do not delete) the examples so they can feed future runs
        for example in run.training_examples:
            example.training_run_id = None
            example.used_in_training = False
        db.session.delete(run)
        db.session.commit()
        return jsonify({
            'success': True,
            'message': f'Training run #{run_id} deleted successfully'
        })
    except Exception as e:
        db.session.rollback()
        logger.error(f"Error deleting training run: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500
def force_delete_training_run(run_id):
    """Delete a training run unconditionally, bypassing every safety check."""
    import shutil
    try:
        run = FineTuningRun.query.get_or_404(run_id)
        if run.is_active_model:
            # Deactivate first so nothing points at a soon-to-be-deleted run
            run.is_active_model = False
            logger.warning(f"Force deleting active model run #{run_id}")
        if run.model_path and os.path.exists(run.model_path):
            try:
                shutil.rmtree(run.model_path)
                logger.info(f"Deleted model files at {run.model_path}")
            except Exception as e:
                # File-system failures do not block the database cleanup
                logger.error(f"Error deleting model files: {str(e)}")
        # Keep the examples themselves; just unlink them from this run
        for example in run.training_examples:
            example.training_run_id = None
            example.used_in_training = False
        db.session.delete(run)
        db.session.commit()
        return jsonify({
            'success': True,
            'message': f'Training run #{run_id} force deleted successfully'
        })
    except Exception as e:
        db.session.rollback()
        logger.error(f"Error force deleting training run: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500
def export_model(run_id):
    """Export a trained model as a downloadable ZIP file.

    Only runs with status 'completed' and an existing model_path can be
    exported. The archive contains the model directory plus a generated
    model_card.json and README.md. The ZIP is built in a temp directory,
    read fully into memory, and streamed via send_file so the temp files
    can be removed before the response returns.
    """
    try:
        import tempfile
        import shutil
        from datetime import datetime
        run = FineTuningRun.query.get_or_404(run_id)
        if run.status != 'completed':
            return jsonify({
                'success': False,
                'error': 'Can only export completed training runs'
            }), 400
        if not run.model_path or not os.path.exists(run.model_path):
            return jsonify({
                'success': False,
                'error': 'Model files not found'
            }), 404
        # Create temporary directory for export
        temp_dir = tempfile.mkdtemp()
        try:
            export_name = f"model_run_{run_id}"
            export_path = os.path.join(temp_dir, export_name)
            # Copy model files
            shutil.copytree(run.model_path, export_path)
            # Create model card with metadata
            config = run.get_config()
            results = run.get_results()
            # NOTE(review): results is assumed non-None for completed runs —
            # the .get() calls below would fail otherwise; confirm get_results()
            model_card = {
                'run_id': run_id,
                'export_date': datetime.utcnow().isoformat(),
                'created_at': run.created_at.isoformat() if run.created_at else None,
                'training_mode': config.get('training_mode', 'lora'),
                'base_model': 'facebook/bart-large-mnli',
                'model_type': 'BART fine-tuned for text classification',
                'task': 'Multi-class text classification',
                'categories': ['Vision', 'Problem', 'Objectives', 'Directives', 'Values', 'Actions'],
                'training_config': config,
                'results': results,
                'improvement_over_baseline': run.improvement_over_baseline,
                'num_training_examples': run.num_training_examples,
                'num_validation_examples': run.num_validation_examples,
                'num_test_examples': run.num_test_examples
            }
            with open(os.path.join(export_path, 'model_card.json'), 'w') as f:
                json.dump(model_card, f, indent=2)
            # Create README — this is an f-string, so the {run_id} placeholders
            # (including those inside the usage code snippet) are interpolated
            readme_content = f"""# Participatory Planning Model - Run {run_id}
## Model Information
- **Export Date**: {datetime.utcnow().strftime('%Y-%m-%d %H:%M UTC')}
- **Training Mode**: {config.get('training_mode', 'lora').upper()}
- **Base Model**: facebook/bart-large-mnli
- **Task**: Multi-class text classification
## Categories
1. Vision
2. Problem
3. Objectives
4. Directives
5. Values
6. Actions
## Training Configuration
- **Learning Rate**: {config.get('learning_rate', 'N/A')}
- **Epochs**: {config.get('num_epochs', 'N/A')}
- **Batch Size**: {config.get('batch_size', 'N/A')}
- **Training Examples**: {run.num_training_examples}
- **Validation Examples**: {run.num_validation_examples}
- **Test Examples**: {run.num_test_examples}
## Performance
- **Test Accuracy**: {results.get('test_accuracy', 0)*100:.1f}%
- **Improvement over Baseline**: {run.improvement_over_baseline*100:.1f}%
## Usage
To load this model:
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained("./model_run_{run_id}")
model = AutoModelForSequenceClassification.from_pretrained("./model_run_{run_id}")
```
See model_card.json for detailed metrics.
"""
            with open(os.path.join(export_path, 'README.md'), 'w') as f:
                f.write(readme_content)
            # Create ZIP file (make_archive appends the '.zip' suffix itself)
            zip_path = os.path.join(temp_dir, f"model_run_{run_id}")
            shutil.make_archive(zip_path, 'zip', temp_dir, export_name)
            zip_file = f"{zip_path}.zip"
            # Read ZIP file into memory before cleaning up temp dir
            with open(zip_file, 'rb') as f:
                zip_data = io.BytesIO(f.read())
            # Clean up temp directory
            shutil.rmtree(temp_dir)
            # Send file from memory
            zip_data.seek(0)
            return send_file(
                zip_data,
                mimetype='application/zip',
                as_attachment=True,
                download_name=f'participatory_planner_model_run_{run_id}_{datetime.now().strftime("%Y%m%d")}.zip'
            )
        except Exception as e:
            # Clean up temp dir if error occurs
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            raise e
    except Exception as e:
        logger.error(f"Error exporting model: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500
def import_model():
    """Import a previously exported model from ZIP file.

    Expects a multipart upload under 'file'. The archive must contain exactly
    one top-level directory holding config.json and model weights
    (pytorch_model.bin or model.safetensors). A new FineTuningRun row is
    created with status 'completed'; metadata is restored from the bundled
    model_card.json when present, and the extracted files are copied into
    MODELS_DIR under run_<id>.
    """
    try:
        import tempfile
        import zipfile
        import shutil
        if 'file' not in request.files:
            return jsonify({'success': False, 'error': 'No file uploaded'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'success': False, 'error': 'No file selected'}), 400
        if not file.filename.endswith('.zip'):
            return jsonify({'success': False, 'error': 'File must be a ZIP archive'}), 400
        # Create temporary directory for extraction
        with tempfile.TemporaryDirectory() as temp_dir:
            # Save uploaded ZIP
            zip_path = os.path.join(temp_dir, 'upload.zip')
            file.save(zip_path)
            # Extract ZIP
            extract_dir = os.path.join(temp_dir, 'extracted')
            os.makedirs(extract_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)
            # Find the model directory (should be model_run_X)
            contents = os.listdir(extract_dir)
            if len(contents) != 1:
                return jsonify({'success': False, 'error': 'Invalid model archive structure'}), 400
            model_dir = os.path.join(extract_dir, contents[0])
            # Validate it's a valid model
            required_files = ['config.json']
            model_files = ['pytorch_model.bin', 'model.safetensors']  # Either format
            has_config = os.path.exists(os.path.join(model_dir, 'config.json'))
            has_model = any(os.path.exists(os.path.join(model_dir, f)) for f in model_files)
            if not has_config or not has_model:
                return jsonify({
                    'success': False,
                    'error': 'Invalid model archive - missing required files (config.json and model weights)'
                }), 400
            # Read model card if available (written by export_model)
            model_info = {}
            model_card_path = os.path.join(model_dir, 'model_card.json')
            if os.path.exists(model_card_path):
                with open(model_card_path, 'r') as f:
                    model_info = json.load(f)
            # Create new training run record; imported models arrive 'completed'
            training_run = FineTuningRun(
                status='completed',
                created_at=datetime.utcnow()
            )
            # Set config from model card if available
            if 'training_config' in model_info:
                training_run.set_config(model_info['training_config'])
            else:
                # Default config for imported models
                training_run.set_config({
                    'training_mode': 'imported',
                    'imported': True,
                    'original_filename': file.filename
                })
            # Set metadata from model card
            if 'num_training_examples' in model_info:
                training_run.num_training_examples = model_info['num_training_examples']
            if 'num_validation_examples' in model_info:
                training_run.num_validation_examples = model_info['num_validation_examples']
            if 'num_test_examples' in model_info:
                training_run.num_test_examples = model_info['num_test_examples']
            if 'results' in model_info:
                training_run.set_results(model_info['results'])
            if 'improvement_over_baseline' in model_info:
                training_run.improvement_over_baseline = model_info['improvement_over_baseline']
            training_run.completed_at = datetime.utcnow()
            db.session.add(training_run)
            # First commit assigns the run id needed for the destination path
            db.session.commit()
            # Copy model to models directory
            models_dir = os.getenv('MODELS_DIR', '/data/models/finetuned')
            destination_path = os.path.join(models_dir, f'run_{training_run.id}')
            shutil.copytree(model_dir, destination_path)
            training_run.model_path = destination_path
            db.session.commit()
            logger.info(f"Model imported successfully as run {training_run.id}")
            return jsonify({
                'success': True,
                'run_id': training_run.id,
                'message': f'Model imported successfully as run #{training_run.id}',
                'model_info': model_info
            })
    except zipfile.BadZipFile:
        return jsonify({'success': False, 'error': 'Invalid ZIP file'}), 400
    except Exception as e:
        db.session.rollback()
        logger.error(f"Error importing model: {str(e)}")
        return jsonify({'success': False, 'error': str(e)}), 500