Spaces:
Paused
Paused
| import os, io, uuid, json, wave, math | |
| import aiohttp | |
| import aiofiles | |
| import asyncio | |
| import sqlite3 | |
| import bcrypt | |
| import tempfile | |
| import shutil | |
| import subprocess # to use "ffmpeg" in order to use "ffmpeg-python" library for video to audio conversion | |
| from fastapi import FastAPI, Request, Query, Form, UploadFile, File, HTTPException, Depends | |
| from fastapi.responses import HTMLResponse, JSONResponse, FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm | |
| from datetime import datetime, timedelta | |
| from jose import JWTError, jwt | |
| from pydantic import BaseModel, EmailStr | |
| from pydantic import confloat | |
| # from pydub import AudioSegment # rely on "ffmpeg" package from the system and "ffmpeg-python" library | |
| from pydub.silence import split_on_silence | |
| import soundfile as sf | |
| import numpy as np | |
| from scipy import signal | |
| from scipy.io import wavfile | |
| from typing import Optional, List, Dict, Union | |
| from dotenv import load_dotenv | |
| from sqlalchemy import create_engine, Column, Integer, String, Boolean, ForeignKey | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from sqlalchemy.orm import sessionmaker, Session | |
load_dotenv()

# Constants
MAX_DURATION = 3600 * 2  # 2 hours
# CHUNK_SIZE = 1024 * 1024  # 1MB chunks
MAX_FILE_SIZE = 25 * 1024 * 1024  # 25MB

# System prompt asking an assistant to fix spelling and punctuation in
# transcribed text without adding outside context.
system_prompt = """
You are a helpful assistant. Your task is to correct
any spelling discrepancies in the transcribed text. Only add necessary
punctuation such as periods, commas, and capitalization, and use only the
context provided.
"""

# Product names whose spelling should be enforced. extra_prompt is rendered
# once at import time, so mutating `names` afterwards does not update it.
names = []
extra_prompt = f"""
Make sure that the names of the following products are spelled correctly: {', '.join(names)}
"""
# Authentication & Security
# SECRET_KEY has no fallback: it must be supplied via the environment
# (change in production). When unset it is None.
SECRET_KEY = os.getenv("SECRET_KEY")
ALGORITHM = "HS256"  # HMAC-SHA256 signature for the JWTs
ACCESS_TOKEN_EXPIRE_MINUTES = 60  # default access-token lifetime

# Directory definitions: uploaded files live under this relative directory,
# which is created at import time if missing.
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
# Database setup
DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()  # Default to SQLite if not specified
DATABASE_URL = os.getenv("DATABASE_URL")
# When true, init_db() drops every table before recreating them.
EMPTY_DB = os.getenv("EMPTY_DB", "false").lower() == "true"

# Configure engine based on database type
if DB_TYPE == "postgresql":
    if not DATABASE_URL:
        raise ValueError("DATABASE_URL must be set when DB_TYPE is 'postgresql'.")
    # Some hosts hand out legacy "postgres://" URLs, which SQLAlchemy 1.4+ no
    # longer accepts as a dialect name.
    if DATABASE_URL.startswith("postgres://"):
        DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://", 1)
    # PostgreSQL configuration for production
    engine = create_engine(
        DATABASE_URL,
        pool_size=5,
        max_overflow=10,
        pool_timeout=30,
        pool_recycle=1800,  # Recycle connections after 30 minutes
    )
elif DB_TYPE == "sqlite":
    # SQLite configuration for development. DB_TYPE already defaults to sqlite,
    # so fall back to a local database file instead of crashing on a missing URL.
    DATABASE_URL = DATABASE_URL or "sqlite:///./app.db"
    engine = create_engine(
        DATABASE_URL,
        # Sessions may be used from a thread other than the one that created
        # the connection, which SQLite refuses by default.
        connect_args={"check_same_thread": False},
    )
else:
    raise ValueError("Unsupported database type. Use 'sqlite' or 'postgresql'.")

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
# SQLAlchemy Models
class UserModel(Base):
    """ORM model for the ``users`` table: login accounts and their roles."""

    __tablename__ = "users"
    id = Column(Integer, primary_key=True, index=True)
    username = Column(String, unique=True, index=True)  # login name; also the JWT "sub" claim
    email = Column(String, unique=True, index=True)
    hashed_password = Column(String)  # bcrypt hash as produced by get_password_hash()
    disabled = Column(Boolean, default=True)  # Default to disabled; an admin enables new accounts
    is_admin = Column(Boolean, default=False)  # Add admin field
# Create tables.
# NOTE(review): this runs before TranscriptionModel and UserCreditModel are
# declared further down the file, so at import time only the tables registered
# so far (``users``) are created. init_db() calls create_all() again, which
# creates the remaining tables if it is run.
Base.metadata.create_all(bind=engine)
# Dependency to get DB session
def get_db():
    """FastAPI dependency: yield one SQLAlchemy session and always close it afterwards."""
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs after the request handler finishes, even if it raised.
        session.close()
# Pydantic schemas for users
class UserBase(BaseModel):
    """Fields shared by every user schema."""

    username: str
    email: EmailStr


class UserCreate(UserBase):
    """Request body for signup(); carries the plaintext password, which is hashed before storage."""

    password: str


class User(UserBase):
    """User data handed to route handlers; never includes the password hash."""

    id: int
    disabled: bool = True  # Default to disabled
    is_admin: bool = False  # Add admin field

    class Config:
        # Allow building the schema from ORM objects (Pydantic v2 option name).
        from_attributes = True
class Token(BaseModel):
    """Response shape for a successful login."""

    access_token: str
    token_type: str  # "bearer" in login_for_access_token()


class TokenData(BaseModel):
    """Claims extracted from a decoded JWT."""

    username: str | None = None  # value of the "sub" claim
# ORM model for saved transcriptions
class TranscriptionModel(Base):
    """ORM model for the ``transcriptions`` table: one saved transcription per row."""

    __tablename__ = "transcriptions"
    id = Column(String, primary_key=True, index=True)  # string key; NOTE(review): presumably a UUID — confirm with the inserting code
    name = Column(String)
    text = Column(String)
    segments = Column(String)  # Store JSON as string
    audio_file = Column(String)  # NOTE(review): presumably the stored upload's file name — confirm
    user_id = Column(Integer, ForeignKey("users.id"))  # owning user
    created_at = Column(String, default=lambda: datetime.utcnow().isoformat())  # naive UTC time as ISO-8601 text
# Pydantic schemas for transcriptions
class TranscriptionBase(BaseModel):
    """Fields shared by transcription request and response schemas."""

    name: str
    text: str
    # NOTE(review): presumably the start/end/text segment dicts returned by the
    # transcription API — confirm against the code that saves transcriptions.
    segments: Optional[List[Dict[str, Union[str, float]]]] = []
    audio_file: Optional[str] = None


class TranscriptionCreate(TranscriptionBase):
    """Request body for creating a transcription; same fields as the base schema."""

    pass


class Transcription(TranscriptionBase):
    """Transcription as returned to clients, including database-assigned fields."""

    id: str
    user_id: int
    created_at: str

    class Config:
        # Allow building the schema from ORM objects (Pydantic v2 option name).
        from_attributes = True
# ORM model and schema for per-user transcription minutes
class UserCreditModel(Base):
    """ORM model for the ``user_credits`` table: transcription minutes used vs. allowed per user."""

    __tablename__ = "user_credits"
    id = Column(Integer, primary_key=True, index=True)
    # Not declared unique; callers look up a user's record with .first().
    user_id = Column(Integer, ForeignKey("users.id"))
    minutes_used = Column(Integer, default=0)
    # Default to 10 minutes. NOTE(review): other code paths in this file create
    # records with a 60-minute quota instead — confirm the intended default.
    minutes_quota = Column(Integer, default=10)
    last_updated = Column(String, default=lambda: datetime.utcnow().isoformat())  # naive UTC ISO-8601 text


class UserCredit(BaseModel):
    """Credit information as exposed to API clients."""

    user_id: int
    minutes_used: int
    minutes_quota: int
    last_updated: str

    class Config:
        # Allow building the schema from ORM objects (Pydantic v2 option name).
        from_attributes = True
# OAuth2 password-flow bearer scheme; clients obtain tokens from the relative URL "token".
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# App
app = FastAPI()

# Allowed CORS origins, comma-separated. Entries are stripped so values such as
# "https://a.example, https://b.example" match, and empty entries are dropped.
CORS_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "http://localhost:8000,http://127.0.0.1:8000").split(",")
    if origin.strip()
]

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

# Create necessary directories (transcriptions are stored in the database, not on disk)
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Mount static files. StaticFiles verifies its directory exists, so "static"
# must be present when the app starts.
app.mount("/static", StaticFiles(directory="static"), name="static")
# NOTE(review): this serves every user's uploads to anyone who knows the file
# path, without authentication — confirm that is intended.
app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name="uploads")
# Initialize database
def init_db():
    """Create every registered table, first dropping them all when EMPTY_DB is set.

    Failures are printed rather than raised, so a broken database does not
    stop the caller.
    """
    metadata = Base.metadata
    try:
        print("Creating the database...")
        if EMPTY_DB:
            # Destructive reset requested through the EMPTY_DB environment flag.
            print("EMPTY_DB flag is set to True. Dropping all tables...")
            metadata.drop_all(bind=engine)
            print("All tables dropped successfully.")
        metadata.create_all(bind=engine)
        print("Database initialized successfully")
    except Exception as exc:
        print(f"Database: {DB_TYPE} initialization error: {str(exc)}")
def verify_password(plain_password, hashed_password):
    """Return True when *plain_password* matches the bcrypt *hashed_password*.

    Either argument may be ``str`` (encoded as UTF-8 here) or ``bytes``. Any
    error raised while encoding or checking, such as a malformed stored hash,
    is printed and treated as a mismatch (False).
    """
    def _as_bytes(value):
        return value.encode('utf-8') if isinstance(value, str) else value

    try:
        stored = _as_bytes(hashed_password)
        candidate = _as_bytes(plain_password)
        return bcrypt.checkpw(candidate, stored)
    except Exception as exc:
        print(f"Password verification error: {str(exc)}")
        return False
def get_password_hash(password):
    """Hash *password* (``str`` or ``bytes``) with a fresh bcrypt salt; return the hash as a UTF-8 string."""
    raw = password.encode('utf-8') if isinstance(password, str) else password
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(raw, salt)
    return digest.decode('utf-8')
def get_user(db: Session, username: str):
    """Return the ``User`` schema for *username*, or ``None`` when no such account exists."""
    record = db.query(UserModel).filter(UserModel.username == username).first()
    if not record:
        return None
    return User(
        id=record.id,
        username=record.username,
        email=record.email,
        disabled=record.disabled,
        is_admin=record.is_admin,
    )
def authenticate_user(db: Session, username: str, password: str):
    """Check *username*/*password* against the stored bcrypt hash.

    Returns the ``User`` schema on success and ``False`` when the user does
    not exist or the password is wrong. Disabled accounts are still returned;
    the caller decides what to do with them.
    """
    record = db.query(UserModel).filter(UserModel.username == username).first()
    if not record or not verify_password(password, record.hashed_password):
        return False
    return User(
        id=record.id,
        username=record.username,
        email=record.email,
        disabled=record.disabled,
        is_admin=record.is_admin,
    )
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    """Resolve the bearer token to the ``User`` it was issued for.

    The token must be a JWT signed with SECRET_KEY/ALGORITHM whose ``sub``
    claim names an existing user. The ``disabled`` flag is not checked here.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when the token
            cannot be decoded, lacks ``sub``, or names an unknown user.
    """
    unauthorized = HTTPException(
        status_code=401,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized

    subject: str = claims.get("sub")
    if subject is None:
        raise unauthorized
    token_data = TokenData(username=subject)

    user = get_user(db, username=token_data.username)
    if user is None:
        raise unauthorized
    return user
async def get_current_active_admin(current_user: User = Depends(get_current_user)):
    """Dependency that only lets enabled administrators through; returns that user.

    Raises:
        HTTPException: 400 if the account is disabled, 403 if it is not an admin.
    """
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    if current_user.is_admin:
        return current_user
    raise HTTPException(status_code=403, detail="Not enough permissions")
async def get_list_users(
    current_user: User = Depends(get_current_active_admin),
    db: Session = Depends(get_db)
):
    """List every account (admin only) as plain dicts without password hashes."""
    exposed_fields = ("id", "username", "email", "disabled", "is_admin")
    return [
        {field: getattr(account, field) for field in exposed_fields}
        for account in db.query(UserModel).all()
    ]
async def enable_user(
    user_id: int,
    current_user: User = Depends(get_current_active_admin),
    db: Session = Depends(get_db)
):
    """Enable the account *user_id* (admin only).

    ``current_user`` is unused in the body; the dependency exists to require
    an active admin caller.

    Raises:
        HTTPException: 404 if the user does not exist.
    """
    user = db.query(UserModel).filter(UserModel.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    # Enabling an account can never reduce the number of active admins, so no
    # "last active admin" guard applies here. Such a guard would wrongly reject
    # an admin re-enabling their own account.
    user.disabled = False
    db.commit()
    return {"message": f"User {user.username} has been enabled"}
async def disable_user(
    user_id: int,
    current_user: User = Depends(get_current_active_admin),
    db: Session = Depends(get_db)
):
    """Disable the account *user_id* (admin only).

    Raises:
        HTTPException: 404 if the user does not exist; 400 if an admin tries
            to disable their own account.
    """
    target = db.query(UserModel).filter(UserModel.id == user_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="User not found")

    is_self = target.id == current_user.id
    if is_self and target.is_admin:
        raise HTTPException(status_code=400, detail="Admins cannot disable themselves")

    target.disabled = True
    db.commit()
    return {"message": f"User {target.username} has been disabled"}
async def delete_user(
    user_id: int,
    current_user: User = Depends(get_current_active_admin),
    db: Session = Depends(get_db)
):
    """Delete the account *user_id* and the rows that reference it (admin only).

    The user's credit record and saved transcriptions are removed in the same
    transaction. Files on disk are not touched.

    Raises:
        HTTPException: 400 when an admin tries to delete their own account,
            404 if the user does not exist. Database errors roll back the
            transaction and propagate.
    """
    # Prevent deleting yourself
    if user_id == current_user.id:
        raise HTTPException(status_code=400, detail="Users cannot delete their own accounts")

    user = db.query(UserModel).filter(UserModel.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    # Read before deletion; the instance is expired after commit.
    username = user.username
    try:
        # Remove rows referencing users.id first: databases that enforce
        # foreign keys (e.g. PostgreSQL) would otherwise reject the delete.
        db.query(UserCreditModel).filter(
            UserCreditModel.user_id == user_id
        ).delete(synchronize_session=False)
        db.query(TranscriptionModel).filter(
            TranscriptionModel.user_id == user_id
        ).delete(synchronize_session=False)
        db.delete(user)
        db.commit()
    except Exception:
        db.rollback()
        raise
    return {"message": f"User {username} has been deleted"}
def create_access_token(data: dict, expires_delta: timedelta | None = None):
    """Return a JWT carrying *data* plus an ``exp`` claim, signed with SECRET_KEY.

    The token lives for *expires_delta*, or ACCESS_TOKEN_EXPIRE_MINUTES when
    no (or a zero) delta is given. *data* itself is not modified.
    """
    lifetime = expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {**data, "exp": datetime.utcnow() + lifetime}
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
### FUNCTIONS
def convert_to_flac(input_file: str, output_file: str):
    """Write *input_file* to *output_file* as 16 kHz mono FLAC.

    Multi-channel audio is averaged to one channel, and other sample rates are
    resampled with ``scipy.signal.resample``. Any failure is re-raised as an
    HTTP 500 error.
    """
    target_rate = 16000
    try:
        samples, rate = sf.read(input_file)
        if samples.ndim > 1:
            samples = samples.mean(axis=1)
        if rate != target_rate:
            samples = signal.resample(samples, int(len(samples) * (target_rate / rate)))
            rate = target_rate
        sf.write(output_file, samples, rate, format='FLAC')
    except Exception as exc:
        raise HTTPException(500, f"Error converting to FLAC: {str(exc)}")
######## ROUTES
# CORS is configured exactly once, from the CORS_ORIGINS environment variable.
# A second CORSMiddleware with allow_origins=["*"] must not be registered here:
# combined with allow_credentials=True it would let any website make
# credentialed requests, overriding the configured allow-list.
# NOTE(review): the route decorator (possibly with response_model=Token) is not
# present in this copy of the file — confirm how this handler is registered.
async def login_for_access_token(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db)
):
    """Exchange a username/password form for a bearer access token.

    Returns:
        ``{"access_token": <jwt>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 for bad credentials, 400 for a disabled account,
            500 for unexpected errors.
    """
    try:
        user = authenticate_user(db, form_data.username, form_data.password)
        if not user:
            raise HTTPException(
                status_code=401,
                detail="Incorrect username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )
        if user.disabled:
            raise HTTPException(
                status_code=400,
                detail="User is disabled"
            )
        access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        access_token = create_access_token(data={"sub": user.username}, expires_delta=access_token_expires)
        return {"access_token": access_token, "token_type": "bearer"}
    except HTTPException:
        # The deliberate 401/400 responses above must reach the client
        # unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        print(f"Login error: {str(e)}")  # Add logging for debugging
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}"
        )
async def signup(user: UserCreate, db: Session = Depends(get_db)):
    """Register a new account.

    The very first account becomes an enabled administrator; every later
    account is created disabled and must be enabled by an admin.

    Raises:
        HTTPException: 400 if the username or email is already taken; 500
            (after rolling back) on any other failure.
    """
    try:
        taken = db.query(UserModel).filter(
            (UserModel.username == user.username) | (UserModel.email == user.email)
        ).first()
        if taken:
            raise HTTPException(
                status_code=400,
                detail="Username or email already exists"
            )

        first_account = db.query(UserModel).count() == 0
        new_user = UserModel(
            username=user.username,
            email=user.email,
            hashed_password=get_password_hash(user.password),
            disabled=not first_account,
            is_admin=first_account,
        )
        db.add(new_user)
        db.commit()
        db.refresh(new_user)
        return {"message": "User created successfully", "is_admin": first_account, "disabled": not first_account}
    except HTTPException:
        raise
    except Exception as exc:
        db.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred: {str(exc)}"
        )
# Helper function to split audio into chunks
def split_audio_chunks(
    file_path: str,
    chunk_size: int = 600,  # 10 minutes in seconds
    overlap: int = 5  # 5 seconds overlap
) -> List[Dict[str, Union[str, float]]]:
    """Split an audio file into overlapping FLAC chunks written next to it.

    Args:
        file_path: audio file readable by soundfile. Chunk *i* is written to
            ``f"{file_path}_chunk_{i}.flac"``.
        chunk_size: chunk length in seconds; must be > 0.
        overlap: seconds shared by consecutive chunks; 0 <= overlap < chunk_size.

    Returns:
        One dict per chunk with ``path`` and ``start``/``end`` in seconds,
        relative to the start of *file_path*. Empty audio yields ``[]``.

    Raises:
        HTTPException: 400 for an invalid chunk_size/overlap, 500 if reading
            or writing audio fails.
    """
    # Validate up front: a non-positive step (overlap >= chunk_size) would loop forever.
    if chunk_size <= 0:
        raise HTTPException(400, "chunk_size must be a positive number of seconds")
    if overlap < 0 or overlap >= chunk_size:
        raise HTTPException(400, "overlap must be non-negative and smaller than chunk_size")

    try:
        # Size is only reported, not enforced.
        file_size = os.path.getsize(file_path)
        if file_size > MAX_FILE_SIZE:
            print(f"Warning: File size ({file_size / (1024 * 1024):.2f}MB) exceeds maximum size limit ({MAX_FILE_SIZE / (1024 * 1024)}MB)")
        else:
            print(f"File size: {file_size / (1024 * 1024):.2f}MB")

        # Read audio file using soundfile (supports multiple formats)
        audio_data, sample_rate = sf.read(file_path)
        # Convert stereo to mono if needed
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)

        total_samples = len(audio_data)
        chunk_size_samples = int(chunk_size * sample_rate)
        overlap_samples = int(overlap * sample_rate)
        # Guard against rounding collapsing the step to zero for fractional inputs.
        step = max(1, chunk_size_samples - overlap_samples)

        chunks = []
        start = 0
        while start < total_samples:
            end = min(start + chunk_size_samples, total_samples)
            chunk_path = f"{file_path}_chunk_{len(chunks)}.flac"
            sf.write(chunk_path, audio_data[start:end], sample_rate, format='FLAC')
            chunks.append({
                "path": chunk_path,
                "start": start / sample_rate,  # Start time in seconds
                "end": end / sample_rate  # End time in seconds
            })
            if end == total_samples:
                # This chunk reaches the end of the audio. Stepping back by the
                # overlap would only emit a short tail already covered above.
                break
            start += step
        return chunks
    except Exception as e:
        raise HTTPException(500, f"Error splitting audio: {str(e)}")
# Process individual chunk
async def process_chunk(
    chunk_path: str,
    model: str,
    language: str,
    temperature: float,
    response_format: str,
    prompt: str,
    chunk_offset: float,
    semaphore: asyncio.Semaphore
):
    """Send one audio chunk to the transcription API and return its result.

    The request goes to ``$OPENAI_BASE_URL/audio/transcriptions`` with
    ``$OPENAI_API_KEY`` as bearer token. Concurrency is bounded by *semaphore*.
    ``language == "auto"`` omits the language field; an empty *prompt* is not sent.

    Returns:
        The parsed JSON response, or the raw body text when
        ``response_format == "text"``. For ``verbose_json`` every segment's
        ``start``/``end`` is shifted by *chunk_offset* seconds. Any failure,
        including an HTTP error status, yields ``{"error": <message>,
        "chunk": chunk_path}`` instead of an exception.
    """
    async with semaphore:
        try:
            async with aiofiles.open(chunk_path, "rb") as f:
                file_data = await f.read()

            form_data = aiohttp.FormData()
            # Send only the base name (it keeps the .flac extension); the full
            # temp path is server-internal.
            form_data.add_field('file', file_data, filename=os.path.basename(chunk_path))
            form_data.add_field('model', model)
            form_data.add_field('temperature', str(temperature))
            form_data.add_field('response_format', response_format)
            if language != "auto":
                form_data.add_field('language', language)
            if prompt:
                form_data.add_field('prompt', prompt)
            print(f"Processing chunk {chunk_path}... prompt: {prompt}")

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{os.getenv('OPENAI_BASE_URL')}/audio/transcriptions",
                    headers={'Authorization': f"Bearer {os.getenv('OPENAI_API_KEY')}"},
                    data=form_data
                ) as response:
                    if response.status >= 400:
                        # Report API failures as chunk errors instead of treating the
                        # error body as a transcription.
                        body = await response.text()
                        return {"error": f"HTTP {response.status}: {body}", "chunk": chunk_path}
                    result = await response.json() if response_format != "text" else await response.text()

            # Adjust timestamps so they are relative to the whole recording
            if response_format == "verbose_json":
                for segment in result.get("segments") or []:
                    segment["start"] += chunk_offset
                    segment["end"] += chunk_offset
            return result
        except Exception as e:
            # Add detailed error logging
            print(f"Error processing chunk {chunk_path}:")
            print(f"Type: {type(e).__name__}")
            print(f"Message: {str(e)}")
            return {"error": str(e), "chunk": chunk_path}
# Combine results from all chunks
def combine_results(results: list, response_format: str):
    """Merge per-chunk transcription results into one response, with validation.

    Args:
        results: one entry per chunk, in chunk order. Each is a dict parsed
            from the API's JSON, a plain string (``response_format == "text"``),
            or an error dict ``{"error": ...}`` for a failed chunk.
        response_format: format requested from the API. ``"verbose_json"``
            additionally requires each chunk to carry a list of segments with
            ``start``, ``end`` and ``text`` keys.

    Returns:
        ``{"text", "segments", "metadata"}``. ``text`` joins the stripped text
        of every successful chunk with single spaces. ``segments`` concatenates
        their segments. ``metadata`` reports ``total_chunks``,
        ``failed_chunks`` and per-chunk ``errors``. A chunk that fails
        validation contributes neither text nor segments.
    """
    final = {"text": "", "segments": []}
    error_chunks = []
    texts = []
    required_keys = ("start", "end", "text")

    for idx, result in enumerate(results):
        # "text" response format: the string itself is the transcript.
        if isinstance(result, str):
            if result.strip():
                texts.append(result.strip())
            continue
        if not isinstance(result, dict):
            error_chunks.append({
                "chunk": idx,
                "error": f"Unexpected result type: {type(result).__name__}"
            })
            continue
        if "error" in result:
            error_chunks.append({
                "chunk": idx,
                "error": result["error"]
            })
            continue

        segments = result.get("segments")
        # Validate response structure
        if response_format == "verbose_json":
            if not isinstance(segments, list):
                error_chunks.append({
                    "chunk": idx,
                    "error": "Invalid segments format"
                })
                continue
            # Validate segment timestamps; reject the whole chunk if any segment is malformed
            if not all(
                isinstance(segment, dict) and all(key in segment for key in required_keys)
                for segment in segments
            ):
                error_chunks.append({
                    "chunk": idx,
                    "error": f"Invalid segment format in chunk {idx}"
                })
                continue

        if isinstance(segments, list):
            final["segments"].extend(segments)
        text = result.get("text")
        if isinstance(text, str) and text.strip():
            texts.append(text.strip())

    final["text"] = " ".join(texts)
    # Add debug information to response
    final["metadata"] = {
        "total_chunks": len(results),
        "failed_chunks": len(error_chunks),
        "errors": error_chunks
    }
    return final
def get_user_upload_dir(user_id):
    """Return ``UPLOAD_DIR/<user_id>``, creating the directory first if it does not exist."""
    directory = os.path.join(UPLOAD_DIR, str(user_id))
    os.makedirs(directory, exist_ok=True)
    return directory
async def read_users_me(current_user: User = Depends(get_current_user)):
    """Return the authenticated caller's ``User`` schema (no password hash)."""
    me = current_user
    return me
async def read_index():
    """Serve ``index.html`` from the working directory as an HTML response."""
    # Decode explicitly as UTF-8 rather than with the platform's locale encoding.
    async with aiofiles.open("index.html", mode="r", encoding="utf-8") as f:
        content = await f.read()
    return HTMLResponse(content=content)
async def get_user_credits(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return the caller's minute usage, creating a credit record on first access.

    Raises:
        HTTPException: 500 with the error text on any failure.
    """
    try:
        credit = db.query(UserCreditModel).filter(
            UserCreditModel.user_id == current_user.id
        ).first()
        if not credit:
            # NOTE(review): records created here get a 10-minute quota, while other
            # code paths in this file use 60 — confirm which default is intended.
            credit = UserCreditModel(
                user_id=current_user.id,
                minutes_used=0,
                minutes_quota=10,
                last_updated=datetime.utcnow().isoformat()
            )
            db.add(credit)
            db.commit()
            db.refresh(credit)

        used = credit.minutes_used
        quota = credit.minutes_quota
        return {
            "user_id": credit.user_id,
            "minutes_used": used,
            "minutes_quota": quota,
            "minutes_remaining": max(0, quota - used),
            "last_updated": credit.last_updated
        }
    except Exception as exc:
        print(f"Error getting user credits: {str(exc)}")
        raise HTTPException(status_code=500, detail=str(exc))
async def get_all_user_credits(
    current_user: User = Depends(get_current_active_admin),
    db: Session = Depends(get_db)
):
    """List every existing credit record with its username (admin only).

    Users without a credit record are not included. A record whose user no
    longer exists is reported with username ``"Unknown"``.

    Raises:
        HTTPException: 500 with the error text on any failure.
    """
    try:
        credits = db.query(UserCreditModel).all()
        usernames = {account.id: account.username for account in db.query(UserModel).all()}
        return [
            {
                "user_id": record.user_id,
                "username": usernames.get(record.user_id, "Unknown"),
                "minutes_used": record.minutes_used,
                "minutes_quota": record.minutes_quota,
                "minutes_remaining": max(0, record.minutes_quota - record.minutes_used),
                "last_updated": record.last_updated
            }
            for record in credits
        ]
    except Exception as exc:
        print(f"Error getting all user credits: {str(exc)}")
        raise HTTPException(status_code=500, detail=str(exc))
async def update_user_credits(
    user_id: int,
    credit_update: dict,
    current_user: User = Depends(get_current_active_admin),
    db: Session = Depends(get_db)
):
    """Set ``minutes_used`` and/or ``minutes_quota`` for *user_id* (admin only).

    When the user has no credit record yet, one is created. Missing fields
    then default to 0 minutes used and a 60-minute quota.

    Returns:
        The resulting credit summary, including username and minutes_remaining.

    Raises:
        HTTPException: 400 when no field is given or a value is not a
            non-negative integer, 404 for an unknown user, 500 for unexpected
            errors. The transaction is rolled back on every error.
    """
    def _non_negative_int(field, value):
        # bool is a subclass of int; reject it so JSON true/false is not stored as 1/0.
        if isinstance(value, bool) or not isinstance(value, int) or value < 0:
            raise HTTPException(status_code=400, detail=f"{field} must be a non-negative integer")
        return value

    try:
        # Validate input
        if not any(field in credit_update for field in ["minutes_used", "minutes_quota"]):
            raise HTTPException(status_code=400, detail="At least one of minutes_used or minutes_quota is required")

        user_credit = db.query(UserCreditModel).filter(
            UserCreditModel.user_id == user_id
        ).first()

        if not user_credit:
            # Create new credit record if it doesn't exist
            user = db.query(UserModel).filter(UserModel.id == user_id).first()
            if not user:
                raise HTTPException(status_code=404, detail="User not found")
            minutes_used = _non_negative_int("minutes_used", credit_update.get("minutes_used", 0))
            minutes_quota = _non_negative_int("minutes_quota", credit_update.get("minutes_quota", 60))  # Default quota
            user_credit = UserCreditModel(
                user_id=user_id,
                minutes_used=minutes_used,
                minutes_quota=minutes_quota,
                last_updated=datetime.utcnow().isoformat()
            )
            db.add(user_credit)
        else:
            # Update existing credit record
            if "minutes_used" in credit_update:
                user_credit.minutes_used = _non_negative_int("minutes_used", credit_update["minutes_used"])
            if "minutes_quota" in credit_update:
                user_credit.minutes_quota = _non_negative_int("minutes_quota", credit_update["minutes_quota"])
            user_credit.last_updated = datetime.utcnow().isoformat()

        db.commit()
        db.refresh(user_credit)

        # Get username for response
        user = db.query(UserModel).filter(UserModel.id == user_id).first()
        username = user.username if user else "Unknown"
        return {
            "user_id": user_credit.user_id,
            "username": username,
            "minutes_used": user_credit.minutes_used,
            "minutes_quota": user_credit.minutes_quota,
            "minutes_remaining": max(0, user_credit.minutes_quota - user_credit.minutes_used),
            "last_updated": user_credit.last_updated
        }
    except HTTPException:
        # Keep the intended 400/404 status; discard any partially applied change.
        db.rollback()
        raise
    except Exception as e:
        db.rollback()
        print(f"Error updating user credits: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def reset_user_quota(
    user_id: int,
    quota_data: dict,
    current_user: User = Depends(get_current_active_admin),
    db: Session = Depends(get_db)
):
    """Set a new minutes quota for *user_id* (admin only).

    Existing usage (``minutes_used``) is kept. A missing credit record is
    created with 0 minutes used.

    Raises:
        HTTPException: 400 when ``new_quota`` is missing or not a non-negative
            integer, 404 for an unknown user, 500 for unexpected errors. The
            transaction is rolled back on every error.
    """
    try:
        new_quota = quota_data.get("new_quota")
        if new_quota is None:
            raise HTTPException(status_code=400, detail="new_quota field is required")
        # bool is a subclass of int; reject it so JSON true/false is not accepted.
        if isinstance(new_quota, bool) or not isinstance(new_quota, int) or new_quota < 0:
            raise HTTPException(status_code=400, detail="new_quota must be a non-negative integer")

        user_credit = db.query(UserCreditModel).filter(
            UserCreditModel.user_id == user_id
        ).first()

        if not user_credit:
            # Create new credit record if it doesn't exist
            user = db.query(UserModel).filter(UserModel.id == user_id).first()
            if not user:
                raise HTTPException(status_code=404, detail="User not found")
            user_credit = UserCreditModel(
                user_id=user_id,
                minutes_used=0,
                minutes_quota=new_quota,
                last_updated=datetime.utcnow().isoformat()
            )
            db.add(user_credit)
        else:
            # Update existing credit record
            user_credit.minutes_quota = new_quota
            user_credit.last_updated = datetime.utcnow().isoformat()

        db.commit()
        db.refresh(user_credit)

        # Get username for response
        user = db.query(UserModel).filter(UserModel.id == user_id).first()
        username = user.username if user else "Unknown"
        return {
            "user_id": user_credit.user_id,
            "username": username,
            "minutes_used": user_credit.minutes_used,
            "minutes_quota": user_credit.minutes_quota,
            "minutes_remaining": max(0, user_credit.minutes_quota - user_credit.minutes_used),
            "last_updated": user_credit.last_updated
        }
    except HTTPException:
        # Keep the intended 400/404 status instead of rewrapping it as a 500.
        db.rollback()
        raise
    except Exception as e:
        db.rollback()
        print(f"Error resetting user quota: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def upload_audio(
    request: Request,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Store an uploaded file in the caller's upload directory.

    Expects a multipart form with a ``file`` field. The stored name is
    ``<uuid4>_<base name of the client filename>``. Directory components in
    the client-supplied name are discarded so the file cannot be written
    outside the user's directory.

    Returns:
        ``JSONResponse({"filename": <stored name>})``.

    Raises:
        HTTPException: 400 when the ``file`` field is missing or is not a
            file upload, 500 on any other error.
    """
    try:
        form_data = await request.form()
        file = form_data.get("file")
        # Plain text form fields arrive as str; only file uploads are accepted.
        if file is None or isinstance(file, str):
            raise HTTPException(status_code=400, detail="A 'file' form field is required")

        # Treat backslashes as separators too, then keep only the final component
        # (prevents "../" path traversal via the client-supplied filename).
        original_name = os.path.basename((file.filename or "").replace("\\", "/")) or "upload"
        filename = f"{uuid.uuid4()}_{original_name}"

        user_upload_dir = get_user_upload_dir(current_user.id)
        file_path = os.path.join(user_upload_dir, filename)
        print(f"file_path: {file_path}")

        contents = await file.read()
        async with aiofiles.open(file_path, "wb") as f:
            await f.write(contents)
        return JSONResponse({"filename": filename})
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
| async def create_upload_file( | |
| current_user: User = Depends(get_current_user), | |
| file: UploadFile = File(...), | |
| model: str = Form(default="whisper-large-v3-turbo"), | |
| language: str = Form(default="auto"), | |
| temperature: float = Form(default=0.0), | |
| response_format: str = Form(default="verbose_json"), | |
| prompt: str = Form(default=""), | |
| chunk_size: int = Form(default=600), | |
| overlap: int = Form(default=5), | |
| start_time: float = Form(default=0.0), | |
| end_time: float = Form(default=None), | |
| db: Session = Depends(get_db) | |
| ): | |
| try: | |
| # Validate input | |
| if not file.content_type.startswith('audio/'): | |
| raise HTTPException(400, "Only audio files allowed") | |
| # Save original file | |
| temp_dir = tempfile.gettempdir() | |
| # os.makedirs(temp_dir, exist_ok=True) | |
| temp_path = os.path.join(temp_dir, file.filename) | |
| temp_path = f"{temp_dir}/{file.filename}" | |
| async with aiofiles.open(temp_path, "wb") as f: | |
| await f.write(await file.read()) | |
| # Read the audio file | |
| audio_data, sample_rate = sf.read(temp_path) | |
| # Convert stereo to mono if needed | |
| if len(audio_data.shape) > 1: | |
| audio_data = np.mean(audio_data, axis=1) | |
| # Calculate time boundaries in samples | |
| start_sample = int(start_time * sample_rate) | |
| end_sample = int(end_time * sample_rate) if end_time else len(audio_data) | |
| # Validate boundaries | |
| if start_sample >= len(audio_data): | |
| raise HTTPException(400, "Start time exceeds audio duration") | |
| if end_sample > len(audio_data): | |
| end_sample = len(audio_data) | |
| if start_sample >= end_sample: | |
| raise HTTPException(400, "Invalid time range") | |
| # Calculate audio duration in minutes | |
| audio_duration_seconds = (end_sample - start_sample) / sample_rate | |
| audio_duration_minutes = math.ceil(audio_duration_seconds / 60) # Round up to nearest minute | |
| # Check user quota | |
| user_credit = db.query(UserCreditModel).filter( | |
| UserCreditModel.user_id == current_user.id | |
| ).first() | |
| if not user_credit: | |
| # Create new credit record if it doesn't exist | |
| user_credit = UserCreditModel( | |
| user_id=current_user.id, | |
| minutes_used=0, | |
| minutes_quota=60, # Default quota | |
| last_updated=datetime.utcnow().isoformat() | |
| ) | |
| db.add(user_credit) | |
| db.commit() | |
| db.refresh(user_credit) | |
| # Check if user has enough quota | |
| minutes_remaining = max(0, user_credit.minutes_quota - user_credit.minutes_used) | |
| if audio_duration_minutes > minutes_remaining: | |
| raise HTTPException( | |
| status_code=403, | |
| detail=f"Quota exceeded. You have {minutes_remaining} minutes remaining, but this audio requires {audio_duration_minutes} minutes." | |
| ) | |
| # Update user credits | |
| user_credit.minutes_used += audio_duration_minutes | |
| user_credit.last_updated = datetime.utcnow().isoformat() | |
| db.commit() | |
| # Slice the audio data | |
| audio_data = audio_data[start_sample:end_sample] | |
| # Save the sliced audio | |
| sliced_path = f"{temp_path}_sliced.flac" | |
| sf.write(sliced_path, audio_data, sample_rate, format='FLAC') | |
| # check sliced_path file size if it exceeded the limit | |
| file_size = os.path.getsize(sliced_path) | |
| if file_size > MAX_FILE_SIZE: | |
| print(f"Slice: {sliced_path} exceeds the limit of {MAX_FILE_SIZE / (1024 * 1024)}MB") | |
| # raise HTTPException(400, f"File size exceeds the limit of {MAX_FILE_SIZE / (1024 * 1024)}MB") | |
| # Split into FLAC chunks | |
| chunks = split_audio_chunks( | |
| sliced_path, | |
| chunk_size=chunk_size, | |
| overlap=overlap | |
| ) | |
| # Process chunks concurrently | |
| semaphore = asyncio.Semaphore(3) | |
| tasks = [process_chunk( | |
| chunk_path=chunk["path"], | |
| model=model, | |
| language=language, | |
| temperature=temperature, | |
| response_format=response_format, | |
| prompt=prompt, | |
| chunk_offset=chunk["start"] + start_time, # Adjust offset to include start_time | |
| semaphore=semaphore | |
| ) for chunk in chunks] | |
| results = await asyncio.gather(*tasks) | |
| # Combine results | |
| combined = combine_results(results, response_format) | |
| # Add credit usage information to response | |
| combined["metadata"]["credit_usage"] = { | |
| "minutes_used": audio_duration_minutes, | |
| "total_minutes_used": user_credit.minutes_used, | |
| "minutes_quota": user_credit.minutes_quota, | |
| "minutes_remaining": max(0, user_credit.minutes_quota - user_credit.minutes_used) | |
| } | |
| # Cleanup | |
| for chunk in chunks: | |
| os.remove(chunk["path"]) | |
| os.remove(temp_path) | |
| os.remove(sliced_path) | |
| return combined | |
| except Exception as e: | |
| # Cleanup in case of error | |
| if 'temp_path' in locals() and os.path.exists(temp_path): | |
| os.remove(temp_path) | |
| if 'sliced_path' in locals() and os.path.exists(sliced_path): | |
| os.remove(sliced_path) | |
| raise HTTPException(500, str(e)) | |
async def upload_audio_chunk(
    file: UploadFile = File(...),
    upload_id: str = Form(...),
    offset: int = Form(...),
    total_size: int = Form(...),
    current_user: User = Depends(get_current_user)
):
    """Store one chunk of a chunked audio upload in the system temp directory.

    The chunk is written to ``<tempdir>/transtudio_upload_<upload_id>/chunk_<offset>``
    and ``metadata.txt`` in the same directory is overwritten with
    ``"<bytes received>/<total_size>"`` as a progress marker.

    Args:
        file: The chunk payload.
        upload_id: Client-chosen identifier shared by all chunks of one upload.
            It becomes part of a directory name, so it must not contain path
            separators.
        offset: Byte offset of this chunk in the complete file; also used as
            the chunk's sort key when the upload is finalized.
        total_size: Size of the complete file in bytes (only recorded).
        current_user: The authenticated user (resolved by ``get_current_user``).

    Returns:
        ``{"success": True, "received_bytes": offset + <bytes in this chunk>}``.

    Raises:
        HTTPException: 400 if ``upload_id`` would escape the temp directory or
            ``offset`` is negative.
    """
    upload_dir_name = f"transtudio_upload_{upload_id}"
    # upload_id comes straight from the client; an ID such as "../../etc"
    # would otherwise turn the directory name into an arbitrary path.
    if os.path.basename(upload_dir_name) != upload_dir_name:
        raise HTTPException(status_code=400, detail="Invalid upload ID")
    if offset < 0:
        raise HTTPException(status_code=400, detail="Offset must be non-negative")

    # NOTE(review): the upload directory is not bound to current_user, so any
    # authenticated user who knows an upload_id can add chunks to it.
    upload_path = os.path.join(tempfile.gettempdir(), upload_dir_name)
    os.makedirs(upload_path, exist_ok=True)

    chunk_path = os.path.join(upload_path, f"chunk_{offset}")
    with open(chunk_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
        # Count what was actually written; UploadFile.size may be None on
        # older Starlette versions.
        chunk_length = buffer.tell()

    received_bytes = offset + chunk_length
    metadata_path = os.path.join(upload_path, "metadata.txt")
    with open(metadata_path, "w") as f:
        f.write(f"{received_bytes}/{total_size}")
    return {"success": True, "received_bytes": received_bytes}
async def finalize_audio_upload(
    upload_info: dict,
    current_user: User = Depends(get_current_user)
):
    """Concatenate the stored chunks of an upload into the user's upload dir.

    Chunks are read from ``<tempdir>/transtudio_upload_<upload_id>`` in
    ascending numeric offset order and appended to a new file. The temporary
    directory is removed afterwards.

    Args:
        upload_info: JSON body; ``"upload_id"`` is required.
        current_user: The authenticated user; the file lands in
            ``get_user_upload_dir(current_user.id)``.

    Returns:
        ``{"success": True, "filename": <generated name>}``.

    Raises:
        HTTPException: 400 if ``upload_id`` is missing or malformed, or if no
            chunks were received; 404 if no such upload directory exists.
    """
    upload_id = upload_info.get("upload_id")
    if not upload_id:
        raise HTTPException(status_code=400, detail="Upload ID is required")
    upload_dir_name = f"transtudio_upload_{upload_id}"
    # The directory is deleted with shutil.rmtree below, so a crafted ID must
    # never be able to point outside the temp directory.
    if os.path.basename(upload_dir_name) != upload_dir_name:
        raise HTTPException(status_code=400, detail="Invalid upload ID")
    upload_path = os.path.join(tempfile.gettempdir(), upload_dir_name)
    if not os.path.isdir(upload_path):
        raise HTTPException(status_code=404, detail="Upload not found")

    # Chunk files are named chunk_<byte offset>; order them numerically.
    chunks = sorted(
        (name for name in os.listdir(upload_path) if name.startswith("chunk_")),
        key=lambda name: int(name.split("_")[1]),
    )
    if not chunks:
        raise HTTPException(status_code=400, detail="No chunks received for this upload")

    # The original client filename is not tracked, so the result always gets
    # a fresh random name with an .mp3 extension, whatever the real format.
    filename = f"{uuid.uuid4()}.mp3"
    user_upload_dir = get_user_upload_dir(current_user.id)
    output_path = os.path.join(user_upload_dir, filename)
    with open(output_path, "wb") as outfile:
        for chunk_name in chunks:
            with open(os.path.join(upload_path, chunk_name), "rb") as infile:
                shutil.copyfileobj(infile, outfile)

    shutil.rmtree(upload_path)
    return {"success": True, "filename": filename}
async def cancel_audio_upload(
    upload_info: dict,
    current_user: User = Depends(get_current_user)
):
    """Discard the temporary chunks of an unfinished upload.

    Succeeds even when no upload with the given ID exists, so cancelling is
    safe to repeat.

    Args:
        upload_info: JSON body; ``"upload_id"`` is required.
        current_user: The authenticated user (resolved by ``get_current_user``).

    Returns:
        ``{"success": True}``.

    Raises:
        HTTPException: 400 if ``upload_id`` is missing or would escape the
            temp directory.
    """
    upload_id = upload_info.get("upload_id")
    if not upload_id:
        raise HTTPException(status_code=400, detail="Upload ID is required")
    upload_dir_name = f"transtudio_upload_{upload_id}"
    # Guard the rmtree below against IDs such as "../../home/user".
    if os.path.basename(upload_dir_name) != upload_dir_name:
        raise HTTPException(status_code=400, detail="Invalid upload ID")
    upload_path = os.path.join(tempfile.gettempdir(), upload_dir_name)
    try:
        shutil.rmtree(upload_path)
    except FileNotFoundError:
        pass  # never started, already finalized, or cancelled twice
    return {"success": True}
def _collect_audio_file_entries(user_dir, extra_fields=None):
    """Describe every visible regular file directly inside *user_dir*.

    Hidden files (names starting with '.') and sub-directories are skipped.
    ``extra_fields``, if given, is merged into every entry.
    """
    entries = []
    for filename in os.listdir(user_dir):
        file_path = os.path.join(user_dir, filename)
        if filename.startswith('.') or not os.path.isfile(file_path):
            continue
        file_stats = os.stat(file_path)
        entry = {
            "id": filename,
            "filename": filename,
            "original_filename": filename,
            "size": file_stats.st_size,
            "uploaded_at": datetime.fromtimestamp(file_stats.st_mtime).isoformat(),
        }
        if extra_fields:
            entry.update(extra_fields)
        entries.append(entry)
    return entries


async def get_audio_files(
    current_user: User = Depends(get_current_user)
):
    """List the uploaded audio files visible to the current user.

    Admins see the files of every numerically named per-user directory under
    ``UPLOAD_DIR``. Each entry is extended with ``user_id`` and ``username``;
    the username falls back to ``"User <id>"`` when no matching user row is
    found. Regular users see only their own upload directory. Results are
    sorted newest first by file modification time.

    Raises:
        HTTPException: 500 on any filesystem or database error.
    """
    try:
        if not current_user.is_admin:
            file_list = _collect_audio_file_entries(get_user_upload_dir(current_user.id))
        else:
            file_list = []
            # One session for the whole listing, closed when done. get_db() is
            # consumed as a generator; closing it lets its cleanup run.
            db_gen = get_db()
            db = next(db_gen)
            try:
                for user_id in os.listdir(UPLOAD_DIR):
                    user_dir = os.path.join(UPLOAD_DIR, user_id)
                    # Only numeric per-user directories are listed; any other
                    # entry would make int() below fail the whole request.
                    if not user_id.isdecimal() or not os.path.isdir(user_dir):
                        continue
                    uid = int(user_id)
                    user = db.query(UserModel).filter(UserModel.id == uid).first()
                    username = user.username if user else f"User {user_id}"
                    file_list.extend(
                        _collect_audio_file_entries(
                            user_dir, {"user_id": uid, "username": username}
                        )
                    )
            finally:
                db_gen.close()
        # ISO-8601 strings in the same format sort chronologically.
        file_list.sort(key=lambda x: x["uploaded_at"], reverse=True)
        return file_list
    except Exception as e:
        print(f"Error getting audio files: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def delete_audio_file(
    filename: str,
    user_id: Optional[int] = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete an uploaded audio file.

    Args:
        filename: Bare name of the file inside the owner's upload directory.
        user_id: Owner of the file. Honoured only for admins; regular users
            always delete from their own directory.
        current_user: The authenticated user.
        db: Database session, used to check that ``user_id`` exists.

    Returns:
        ``{"status": "success", "message": ...}``.

    Raises:
        HTTPException: 400 for a filename containing path components; 404 if
            the target user or file does not exist; 500 on other errors.
    """
    try:
        # Determine which user's file to delete.
        target_user_id = user_id if current_user.is_admin and user_id else current_user.id
        # Check that the user exists when an admin deletes another user's file.
        if current_user.is_admin and user_id and user_id != current_user.id:
            user = db.query(UserModel).filter(UserModel.id == user_id).first()
            if not user:
                raise HTTPException(status_code=404, detail="User not found")
        # filename is client-supplied; accept only a bare name so it cannot
        # reach into another user's directory (e.g. "../7/audio.mp3").
        if not filename or filename in (".", "..") or os.path.basename(filename) != filename:
            raise HTTPException(status_code=400, detail="Invalid filename")
        user_upload_dir = get_user_upload_dir(target_user_id)
        file_path = os.path.join(user_upload_dir, filename)
        # isfile (not exists): a directory must not reach os.remove.
        if not os.path.isfile(file_path):
            raise HTTPException(status_code=404, detail="Audio file not found")
        os.remove(file_path)
        return {"status": "success", "message": "Audio file deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        print(f"Error deleting audio file: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def extract_audio(file: UploadFile = File(...)):
    """Extract the audio of an uploaded video with ffmpeg and return it as MP3.

    The upload is written to a private temporary directory. ffmpeg is run with
    ``-q:a 0 -map a``, and the resulting ``.mp3`` is streamed back. The
    temporary directory is removed after the response has been sent, or
    immediately on failure.

    Raises:
        HTTPException: 400 if ffmpeg exits with an error or produces no output.
    """
    from fastapi import BackgroundTasks  # local import keeps this block self-contained

    work_dir = tempfile.mkdtemp(prefix="transtudio_extract_")
    # Never build paths from the client-supplied filename (path traversal and
    # name collisions); keep only its extension as a hint for ffmpeg.
    extension = os.path.splitext(os.path.basename(file.filename or ""))[1]
    video_path = os.path.join(work_dir, f"input{extension}")
    audio_path = os.path.join(work_dir, "output.mp3")
    try:
        with open(video_path, "wb") as buffer:
            buffer.write(await file.read())
        # Run ffmpeg in a worker thread so the event loop is not blocked.
        # "-y" and a closed stdin ensure ffmpeg never waits on a prompt.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            lambda: subprocess.run(
                ["ffmpeg", "-y", "-i", video_path,
                 "-q:a", "0", "-map", "a", audio_path],
                stdin=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
            ),
        )
        os.remove(video_path)
        if result.returncode != 0 or not os.path.isfile(audio_path):
            print(f"ffmpeg failed: {result.stderr.decode(errors='replace')[-500:]}")
            raise HTTPException(
                status_code=400,
                detail="Could not extract audio from the uploaded file",
            )
    except BaseException:
        # Clean up on any failure (including cancellation), then propagate.
        shutil.rmtree(work_dir, ignore_errors=True)
        raise

    cleanup = BackgroundTasks()
    cleanup.add_task(shutil.rmtree, work_dir, ignore_errors=True)
    return FileResponse(audio_path, background=cleanup)
async def get_transcriptions(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return every transcription stored for the current user.

    Segments are persisted as a JSON string and decoded here; an empty or
    missing value is returned as ``[]``.

    Raises:
        HTTPException: 500 on any database or JSON-decoding error.
    """
    def to_payload(record):
        return {
            "id": record.id,
            "name": record.name,
            "text": record.text,
            "segments": json.loads(record.segments) if record.segments else [],
            "audio_file": record.audio_file,
        }

    try:
        owned = db.query(TranscriptionModel).filter(
            TranscriptionModel.user_id == current_user.id
        )
        return list(map(to_payload, owned.all()))
    except Exception as e:
        print(f"Error getting transcriptions: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def save_transcription(
    data: dict,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Persist a new transcription for the current user.

    Args:
        data: JSON body. ``"text"`` is required. ``"segments"`` is optional,
            defaults to ``[]``, and is stored as a JSON string.
            ``"audio_file"`` is optional.
        current_user: Owner of the new record.
        db: Database session.

    Returns:
        ``{"status": "success", "id": <new uuid4 string>}``.

    Raises:
        HTTPException: 400 if ``"text"`` is missing; 500 on database errors,
            after rolling the session back.
    """
    # Validate before the try block so a malformed request is reported as 400
    # instead of surfacing as a KeyError converted to 500.
    if "text" not in data:
        raise HTTPException(status_code=400, detail="Field 'text' is required")
    try:
        trans_id = str(uuid.uuid4())
        # Create the new transcription row; the display name is derived from the id.
        transcription = TranscriptionModel(
            id=trans_id,
            name=f"t{trans_id[:8]}",
            text=data["text"],
            segments=json.dumps(data.get("segments", [])),
            audio_file=data.get("audio_file"),
            user_id=current_user.id
        )
        db.add(transcription)
        db.commit()
        return {"status": "success", "id": trans_id}
    except Exception as e:
        db.rollback()
        print(f"Error saving transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_transcription(
    trans_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Fetch a single transcription belonging to the current user.

    Records owned by other users are indistinguishable from missing ones.

    Raises:
        HTTPException: 404 if no matching transcription exists for this user;
            500 on any other error.
    """
    try:
        record = (
            db.query(TranscriptionModel)
            .filter(
                TranscriptionModel.id == trans_id,
                TranscriptionModel.user_id == current_user.id,
            )
            .first()
        )
        if not record:
            raise HTTPException(status_code=404, detail="Transcription not found")
        payload = {
            "id": record.id,
            "name": record.name,
            "text": record.text,
            "segments": json.loads(record.segments) if record.segments else [],
            "audio_file": record.audio_file,
        }
    except HTTPException:
        raise
    except Exception as e:
        print(f"Error getting transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    return payload
async def delete_transcription(
    trans_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete a transcription belonging to the current user.

    Raises:
        HTTPException: 404 if the user owns no transcription with ``trans_id``;
            500 on database errors, after rolling the session back.
    """
    try:
        target = (
            db.query(TranscriptionModel)
            .filter(
                TranscriptionModel.id == trans_id,
                TranscriptionModel.user_id == current_user.id,
            )
            .first()
        )
        if not target:
            raise HTTPException(status_code=404, detail="Transcription not found")
        db.delete(target)
        db.commit()
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        print(f"Error deleting transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "success", "message": "Transcription deleted successfully"}
async def startup_event():
    """Initialise the database by delegating to ``init_db()``.

    The name suggests this is registered as an application startup hook;
    the registration itself is not visible here.
    """
    init_db()
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on the
    # loopback interface, port 8000.
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)