Spaces:
Sleeping
Sleeping
| """ | |
| Authentication and authorization utilities | |
| """ | |
| import jwt | |
| import bcrypt | |
| import json | |
| import os | |
| from datetime import datetime, timedelta | |
| from typing import Optional, Dict, Any | |
| from functools import wraps | |
| # Import Flask components only when available | |
| try: | |
| from flask import request, jsonify, session, redirect, url_for | |
| FLASK_AVAILABLE = True | |
| except ImportError: | |
| FLASK_AVAILABLE = False | |
| request = None | |
| jsonify = None | |
| session = None | |
| redirect = None | |
| url_for = None | |
class AuthManager:
    """File-backed user accounts, bcrypt password hashing and JWT sessions.

    Users are stored in ``<data_dir>/users.json`` keyed by username. Each user
    id has at most one active token, recorded in
    ``<data_dir>/active_sessions.json``; logging in again replaces it.

    NOTE(review): every mutating method is an unsynchronised
    load-modify-save of a JSON file, so concurrent requests can overwrite each
    other's changes (e.g. a logout racing ``update_session_activity``).
    """

    # Lifetime of both the JWT ``exp`` claim and the server-side session record.
    SESSION_LIFETIME = timedelta(hours=8)

    def __init__(self, secret_key: str = None, data_dir: str = 'data'):
        """Create the manager and make sure the users file exists.

        Args:
            secret_key: HMAC key used to sign tokens. Falls back to the
                ``SECRET_KEY`` environment variable, then to a hard-coded
                development key.
            data_dir: directory holding ``users.json`` and
                ``active_sessions.json``; created if missing.
        """
        # NOTE(review): the fallback key is visible in source code, so tokens
        # signed with it can be forged. SECRET_KEY must be set in production.
        self.secret_key = secret_key or os.environ.get('SECRET_KEY', 'dev-secret-key-change-in-production')
        self.users_file = os.path.join(data_dir, 'users.json')
        # Not read by any method of this class (sessions are file-backed);
        # presumably kept for external code that references it.
        self.active_sessions = {}
        self.session_file = os.path.join(data_dir, 'active_sessions.json')
        self.ensure_users_file()

    # ------------------------------------------------------------------ I/O

    @staticmethod
    def _read_json_dict(path: str) -> Dict[str, Any]:
        """Load a JSON object from ``path``; return ``{}`` if missing, unreadable or not an object."""
        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):  # ValueError covers JSONDecodeError / bad encoding
            return {}
        return data if isinstance(data, dict) else {}

    @staticmethod
    def _write_json_atomic(path: str, data: Dict[str, Any]) -> None:
        """Write ``data`` as indented JSON to ``path`` via a temp file and ``os.replace``.

        A crash mid-write leaves the previous file intact instead of a
        truncated one, which the loaders would otherwise read back as ``{}``.
        """
        import tempfile

        directory = os.path.dirname(path) or '.'
        fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.tmp')
        try:
            with os.fdopen(fd, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, path)
        except BaseException:
            # Don't leave a stray temp file behind, then propagate the error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def ensure_users_file(self):
        """Create the users file (and its directory) as an empty JSON object if missing."""
        os.makedirs(os.path.dirname(self.users_file) or '.', exist_ok=True)
        try:
            # Mode 'x' creates the file only if it does not exist, so a users
            # file written by someone else in the meantime is never clobbered.
            with open(self.users_file, 'x', encoding='utf-8') as f:
                json.dump({}, f)
        except FileExistsError:
            pass

    # ------------------------------------------------------------ passwords

    def hash_password(self, password: str) -> str:
        """Return a salted bcrypt hash of ``password`` as a UTF-8 string."""
        return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')

    def verify_password(self, password: str, hashed: str) -> bool:
        """Return True if ``password`` matches the bcrypt hash ``hashed``.

        If bcrypt rejects the input with ``ValueError`` (e.g. a malformed
        stored hash), the result is a mismatch rather than a server error.
        """
        try:
            return bcrypt.checkpw(password.encode('utf-8'), hashed.encode('utf-8'))
        except ValueError:
            return False

    # ---------------------------------------------------------------- users

    def load_users(self) -> Dict[str, Any]:
        """Return the username -> user record mapping, or ``{}`` if unreadable.

        NOTE(review): an unreadable file yields ``{}``, and the next
        ``save_users`` then replaces it; atomic writes keep this from being
        triggered by our own partial writes.
        """
        return self._read_json_dict(self.users_file)

    def save_users(self, users: Dict[str, Any]):
        """Atomically persist the username -> user record mapping."""
        self._write_json_atomic(self.users_file, users)

    @staticmethod
    def _next_user_id(users: Dict[str, Any]) -> str:
        """Return a ``user_<n>`` id that no stored user already has.

        Starts at ``len(users) + 1`` (so ids match the historic scheme when
        nothing was removed) and steps forward past ids that are taken.
        """
        taken = {u.get('user_id') for u in users.values() if isinstance(u, dict)}
        n = len(users) + 1
        while f"user_{n}" in taken:
            n += 1
        return f"user_{n}"

    def create_user(self, username: str, email: str, password: str) -> Dict[str, Any]:
        """Register a new active user.

        Returns:
            ``{'success': True, 'user_id': ...}`` on success, otherwise
            ``{'success': False, 'error': ...}`` when the username or the
            email address is already registered.
        """
        users = self.load_users()
        if username in users:
            return {'success': False, 'error': 'Username already exists'}
        if any(user_data.get('email') == email for user_data in users.values()):
            return {'success': False, 'error': 'Email already registered'}

        user_id = self._next_user_id(users)
        users[username] = {
            'user_id': user_id,
            'email': email,
            'password_hash': self.hash_password(password),
            'created_at': datetime.now().isoformat(),
            'is_active': True,
        }
        self.save_users(users)
        return {'success': True, 'user_id': user_id}

    def authenticate_user(self, username: str, password: str) -> Dict[str, Any]:
        """Check credentials and, on success, issue a JWT and record the session.

        Returns:
            ``{'success': True, 'token', 'user_id', 'username'}`` or
            ``{'success': False, 'error': ...}``. Unknown users and wrong
            passwords get the same message so usernames are not revealed.
        """
        users = self.load_users()
        user = users.get(username)
        stored_hash = user.get('password_hash') if user else None
        if not stored_hash or not self.verify_password(password, stored_hash):
            return {'success': False, 'error': 'Invalid username or password'}
        if not user.get('is_active', True):
            return {'success': False, 'error': 'Account is disabled'}

        token = jwt.encode({
            'user_id': user['user_id'],
            'username': username,
            'exp': datetime.utcnow() + self.SESSION_LIFETIME,
        }, self.secret_key, algorithm='HS256')
        # Replaces any previous session for this user id.
        self.add_active_session(user['user_id'], token)
        return {
            'success': True,
            'token': token,
            'user_id': user['user_id'],
            'username': username,
        }

    def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Return the token's payload if it is validly signed, unexpired and the current session.

        Also refreshes the session's ``last_activity`` timestamp. Returns
        None for any invalid, expired or superseded token.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=['HS256'])
        except jwt.InvalidTokenError:
            # Base class of ExpiredSignatureError, bad signatures and malformed tokens.
            return None
        user_id = payload.get('user_id')
        if not self.is_session_active(user_id, token):
            return None
        self.update_session_activity(user_id)
        return payload

    def get_current_user(self, request_obj) -> Optional[Dict[str, Any]]:
        """Return the verified token payload for ``request_obj``, or None.

        A ``Bearer`` Authorization header takes precedence; if one is present
        its verdict is final. Otherwise the ``auth_token`` stored in the Flask
        session is checked.
        """
        if not FLASK_AVAILABLE or not request_obj:
            return None

        auth_header = request_obj.headers.get('Authorization')
        if auth_header and auth_header.startswith('Bearer '):
            token = auth_header[len('Bearer '):].strip()
            return self.verify_token(token) if token else None

        if session:
            token = session.get('auth_token')
            if token:
                return self.verify_token(token)
        return None

    def create_default_admin(self, password: Optional[str] = None) -> Dict[str, Any]:
        """Create the ``admin`` account unless that username or its user id already exists.

        Args:
            password: initial admin password. Defaults to the well-known
                ``'admin123'``; pass a value to avoid shipping that default.

        Returns:
            A status dict. When the account is created it also echoes the
            username and initial password so the caller can display them.
        """
        users = self.load_users()
        admin_username = 'admin'
        admin_user_id = 'admin_user'

        if admin_username in users:
            return {'success': True, 'message': 'Admin user already exists'}
        if any(user_data.get('user_id') == admin_user_id for user_data in users.values()):
            return {'success': True, 'message': 'Admin user ID already exists'}

        admin_password = password or 'admin123'
        users[admin_username] = {
            'user_id': admin_user_id,
            'email': 'admin@researchmate.local',
            'password_hash': self.hash_password(admin_password),
            'created_at': datetime.now().isoformat(),
            'is_active': True,
            'is_admin': True,
        }
        self.save_users(users)
        return {
            'success': True,
            'message': 'Default admin user created',
            'username': admin_username,
            'password': admin_password,
            'note': 'Please change the default password after first login',
        }

    # ------------------------------------------------------------- sessions

    def load_active_sessions(self) -> Dict[str, Any]:
        """Return the user_id -> session record mapping, or ``{}`` if unreadable."""
        return self._read_json_dict(self.session_file)

    def save_active_sessions(self, sessions: Dict[str, Any]):
        """Persist the session mapping. This is best effort: failures are logged, not raised."""
        try:
            self._write_json_atomic(self.session_file, sessions)
        except (OSError, TypeError, ValueError) as exc:
            import logging

            logging.getLogger(__name__).warning(
                "Could not save active sessions to %s: %s", self.session_file, exc)

    def add_active_session(self, user_id: str, token: str):
        """Record ``token`` as the only active session for ``user_id``."""
        sessions = self.load_active_sessions()
        now = datetime.now().isoformat()
        sessions[user_id] = {
            'token': token,
            'created_at': now,
            'last_activity': now,
        }
        self.save_active_sessions(sessions)

    def remove_active_session(self, user_id: str):
        """Forget the session for ``user_id`` if there is one."""
        sessions = self.load_active_sessions()
        if user_id in sessions:
            del sessions[user_id]
            self.save_active_sessions(sessions)

    def _session_expired(self, record: Any, now: datetime) -> bool:
        """Return True if a stored session record is past its lifetime or unreadable."""
        if not isinstance(record, dict):
            return True
        try:
            created_at = datetime.fromisoformat(record['created_at'])
            return now - created_at > self.SESSION_LIFETIME
        except (KeyError, TypeError, ValueError):
            # Missing or garbled timestamp, or a tz-aware value that cannot be
            # subtracted from the naive local ``now``: treat it as expired.
            return True

    def is_session_active(self, user_id: str, token: str) -> bool:
        """Return True if ``token`` is ``user_id``'s current, unexpired session.

        An expired (or unreadable) session holding this token is removed.
        """
        sessions = self.load_active_sessions()
        record = sessions.get(user_id)
        if not isinstance(record, dict) or record.get('token') != token:
            return False
        if self._session_expired(record, datetime.now()):
            self.remove_active_session(user_id)
            return False
        return True

    def logout_user(self, user_id: str):
        """Invalidate ``user_id``'s session and return a success status dict."""
        self.remove_active_session(user_id)
        return {'success': True, 'message': 'Logged out successfully'}

    def cleanup_expired_sessions(self):
        """Drop expired or unreadable session records and return how many were removed."""
        sessions = self.load_active_sessions()
        now = datetime.now()
        expired = [uid for uid, record in sessions.items() if self._session_expired(record, now)]
        for uid in expired:
            del sessions[uid]
        if expired:
            self.save_active_sessions(sessions)
        return len(expired)

    def update_session_activity(self, user_id: str):
        """Stamp ``last_activity`` on ``user_id``'s session record, if it has one."""
        sessions = self.load_active_sessions()
        record = sessions.get(user_id)
        if isinstance(record, dict):
            record['last_activity'] = datetime.now().isoformat()
            self.save_active_sessions(sessions)
# Module-wide AuthManager shared by require_auth and get_current_user below.
# secret_key=None means the key comes from the SECRET_KEY environment variable
# (or the built-in fallback); constructing it creates the users file on import.
auth_manager = AuthManager(secret_key=None)
def require_auth(f):
    """Decorator that rejects unauthenticated requests to a Flask view.

    When Flask could not be imported the wrapped function is called
    unchanged. Otherwise the user is resolved via
    ``auth_manager.get_current_user(request)``. If there is none, JSON requests
    get a 401 JSON error and all other requests are redirected to ``login``.
    """
    # functools.wraps copies __name__/__doc__ onto the wrapper. Flask derives a
    # route's default endpoint name from the view function's __name__, so
    # without it every decorated view would be named "decorated_function" and
    # registering a second decorated route would clash with the first.
    @wraps(f)
    def decorated_function(*args, **kwargs):
        if not FLASK_AVAILABLE:
            return f(*args, **kwargs)
        user = auth_manager.get_current_user(request)
        if not user:
            if request.is_json:
                return jsonify({'success': False, 'error': 'Authentication required'}), 401
            return redirect(url_for('login'))
        return f(*args, **kwargs)
    return decorated_function
def get_current_user() -> Optional[Dict[str, Any]]:
    """Return the verified token payload of the user behind the current request.

    Yields None when Flask is unavailable; otherwise delegates to the
    module-level ``auth_manager`` using Flask's ``request`` proxy.
    """
    return auth_manager.get_current_user(request) if FLASK_AVAILABLE else None