import time
import gradio as gr
import pandas as pd
import openvino_genai as ov_genai
from huggingface_hub import snapshot_download
from threading import Lock, Event
import os
import numpy as np
import requests
from PIL import Image
from io import BytesIO
import cpuinfo
import openvino as ov
from googleapiclient.discovery import build
import gc
from PyPDF2 import PdfReader
from docx import Document
import textwrap
from queue import Queue, Empty
from concurrent.futures import ThreadPoolExecutor
from typing import Generator
import warnings
from transformers import pipeline  # Added for Whisper

# Suppress specific OpenVINO deprecation warning
warnings.filterwarnings("ignore", category=DeprecationWarning, module="openvino.runtime")

# Google API configuration
# NOTE: credentials are read from the environment rather than hardcoded;
# set GOOGLE_API_KEY and GOOGLE_CSE_ID in the Space/host environment.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
GOOGLE_CSE_ID = os.environ.get("GOOGLE_CSE_ID", "")

DEFAULT_MAX_TOKENS = 100
DEFAULT_NUM_IMAGES = 1
MAX_HISTORY_TURNS = 3
MAX_TOKENS_LIMIT = 1000
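# DEFAULT_MAX_TOKENS / MAX_TOKENS_LIMIT feed the response-length slider below,
# and MAX_HISTORY_TURNS caps how many (user, bot, images) triples are kept in
# the chat state between turns.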


class UnifiedAISystem:
    def __init__(self):
        self.pipe_lock = Lock()
        self.current_df = None
        self.mistral_pipe = None
        self.internvl_pipe = None
        self.whisper_pipe = None
        self.current_document_text = None
        self.generation_executor = ThreadPoolExecutor(max_workers=3)
        self.initialize_models()
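
    # Mistral is loaded eagerly at startup; InternVL and Whisper are created
    # lazily on first use (see analyze_image / transcribe) to keep startup
    # time and memory down.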
    def initialize_models(self):
        """Initialize all required models"""
        # Download models if not present locally
        model_paths = {
            "mistral-ov": "OpenVINO/mistral-7b-instruct-v0.1-int8-ov",
            "internvl-ov": "OpenVINO/InternVL2-1B-int8-ov"
            # Removed distil-whisper download since we're using the transformers version
        }
        for local_dir, repo_id in model_paths.items():
            if not os.path.exists(local_dir):
                snapshot_download(repo_id=repo_id, local_dir=local_dir)

        # CPU-specific configuration: enforce BF16 on AVX-512 hardware,
        # otherwise hint f32 precision on AVX2. cpuinfo reports AVX-512
        # support as individual flags like 'avx512f'/'avx512bw', so match
        # on substring rather than exact flag name.
        cpu_features = cpuinfo.get_cpu_info()['flags']
        config_properties = {}
        if any('avx512' in flag for flag in cpu_features):
            config_properties["ENFORCE_BF16"] = "YES"
        elif 'avx2' in cpu_features:
            config_properties["INFERENCE_PRECISION_HINT"] = "f32"

        # Initialize Mistral model with updated configuration
        self.mistral_pipe = ov_genai.LLMPipeline(
            "mistral-ov",
            device="CPU",
            PERFORMANCE_HINT="THROUGHPUT",
            **config_properties
        )

    def load_data(self, file_path):
        """Load student data from file"""
        try:
            file_ext = os.path.splitext(file_path)[1].lower()
            if file_ext == '.csv':
                self.current_df = pd.read_csv(file_path)
            elif file_ext in ['.xlsx', '.xls']:
                self.current_df = pd.read_excel(file_path)
            else:
                return False, "❌ Unsupported file format. Please upload a .csv or .xlsx file."
            return True, f"✅ Loaded {len(self.current_df)} records from {os.path.basename(file_path)}"
        except Exception as e:
            return False, f"❌ Error loading file: {str(e)}"

    def extract_text_from_document(self, file_path):
        """Extract text from PDF or DOCX documents"""
        text = ""
        try:
            file_ext = os.path.splitext(file_path)[1].lower()
            if file_ext == '.pdf':
                with open(file_path, 'rb') as file:
                    pdf_reader = PdfReader(file)
                    for page in pdf_reader.pages:
                        # extract_text() can return None for image-only pages
                        text += (page.extract_text() or "") + "\n"
            elif file_ext == '.docx':
                doc = Document(file_path)
                for para in doc.paragraphs:
                    text += para.text + "\n"
            else:
                return False, "❌ Unsupported document format. Please upload PDF or DOCX."

            # Clean and format text
            text = text.replace('\x0c', '')  # Remove form feed characters
            text = textwrap.dedent(text)  # Remove common leading whitespace
            self.current_document_text = text
            return True, f"✅ Extracted text from {os.path.basename(file_path)}"
        except Exception as e:
            return False, f"❌ Error processing document: {str(e)}"
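
    # Streaming design: generation runs on a worker thread from
    # generation_executor; the streamer callback pushes decoded text into a
    # Queue, and this generator drains the queue so Gradio can yield partial
    # responses while the model is still producing tokens.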
    def generate_text_stream(self, prompt: str, max_tokens: int) -> Generator[str, None, None]:
        """Unified text generation with queued token streaming"""
        start_time = time.time()
        response_queue = Queue()
        completion_event = Event()
        error = [None]  # Use a list to capture an exception from the worker thread

        # Streaming is driven by the callback passed to generate() below;
        # do_sample=True is required for temperature/top_p to take effect.
        optimized_config = ov_genai.GenerationConfig(
            max_new_tokens=max_tokens,
            temperature=0.3,
            top_p=0.9,
            do_sample=True
        )

        def callback(tokens):  # Receives newly decoded text from the streamer
            response_queue.put("".join(tokens))
            return ov_genai.StreamingStatus.RUNNING

        def generate():
            try:
                with self.pipe_lock:
                    self.mistral_pipe.generate(prompt, optimized_config, callback)
            except Exception as e:
                error[0] = str(e)
            finally:
                completion_event.set()

        # Submit generation task to executor
        self.generation_executor.submit(generate)

        accumulated = []
        token_count = 0  # Approximate: counts decoded characters, not tokens
        last_gc = time.time()
        while not completion_event.is_set() or not response_queue.empty():
            if error[0]:
                yield f"❌ Error: {error[0]}"
                print(f"Stream generation time: {time.time() - start_time:.2f} seconds")
                return
            try:
                token_batch = response_queue.get(timeout=0.1)
                accumulated.append(token_batch)
                token_count += len(token_batch)
                yield "".join(accumulated)
                # Periodic garbage collection
                if time.time() - last_gc > 2.0:
                    gc.collect()
                    last_gc = time.time()
            except Empty:
                continue
        print(f"Generated {token_count} characters in {time.time() - start_time:.2f} seconds "
              f"({token_count / (time.time() - start_time):.2f} chars/sec)")
        yield "".join(accumulated)

    def analyze_student_data(self, query, max_tokens=500):
        """Analyze student data using AI with streaming"""
        if not query or not query.strip():
            yield "⚠️ Please enter a valid question"
            return
        if self.current_df is None:
            yield "⚠️ Please upload and load a student data file first"
            return

        data_summary = self._prepare_data_summary(self.current_df)
        prompt = f"""You are an expert education analyst. Analyze the following student performance data:

{data_summary}

Question: {query}

Please include:
1. Direct answer to the question
2. Relevant statistics
3. Key insights
4. Actionable recommendations

Format the output with clear headings"""
        # Use unified streaming generator
        yield from self.generate_text_stream(prompt, max_tokens)

    def _prepare_data_summary(self, df):
        """Summarize the uploaded data"""
        summary = f"Student performance data with {len(df)} rows and {len(df.columns)} columns.\n"
        summary += "Columns: " + ", ".join(df.columns) + "\n"
        summary += "First 3 rows:\n" + df.head(3).to_string(index=False)
        return summary
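
    # The OpenVINO GenAI VLM pipeline appears to expect an NHWC uint8 tensor
    # (batch, height, width, channels); the conversion below builds one from
    # the PIL image.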
    def analyze_image(self, image, url, prompt):
        """Analyze image with InternVL model (synchronous, no streaming)"""
        try:
            if image is not None:
                image_source = image.convert("RGB")
            elif url and url.startswith(("http://", "https://")):
                response = requests.get(url, timeout=10)
                image_source = Image.open(BytesIO(response.content)).convert("RGB")
            else:
                return "⚠️ Please upload an image or enter a valid URL"

            # Convert to an OpenVINO tensor; np.uint8 avoids the overflow that
            # signed np.byte would cause for pixel values above 127
            image_data = np.array(image_source.getdata()).reshape(
                1, image_source.size[1], image_source.size[0], 3
            ).astype(np.uint8)
            image_tensor = ov.Tensor(image_data)

            # Lazily initialize InternVL
            if self.internvl_pipe is None:
                self.internvl_pipe = ov_genai.VLMPipeline("internvl-ov", device="CPU")

            with self.pipe_lock:
                self.internvl_pipe.start_chat()
                try:
                    output = self.internvl_pipe.generate(prompt, image=image_tensor, max_new_tokens=100)
                finally:
                    self.internvl_pipe.finish_chat()
            # Ensure output is a string
            return str(output)
        except Exception as e:
            return f"❌ Error: {str(e)}"

    def process_audio(self, data, sr):
        """Process audio data for speech recognition"""
        try:
            # Convert to mono
            if data.ndim > 1:
                data = np.mean(data, axis=1)

            # Convert to float32 and normalize
            data = data.astype(np.float32)
            max_val = np.max(np.abs(data)) + 1e-7
            data /= max_val

            # Simple noise reduction
            data = np.clip(data, -0.5, 0.5)

            # Trim silence
            energy = np.abs(data)
            threshold = np.percentile(energy, 25)  # Simple threshold
            mask = energy > threshold
            indices = np.where(mask)[0]
            if len(indices) > 0:
                start = max(0, indices[0] - 1000)
                end = min(len(data), indices[-1] + 1000)
                data = data[start:end]

            # Resample to 16 kHz if needed, using linear interpolation
            if sr != 16000:
                new_length = int(len(data) * 16000 / sr)
                data = np.interp(
                    np.linspace(0, len(data) - 1, new_length),
                    np.arange(len(data)),
                    data
                )
            return data
        except Exception as e:
            print(f"Audio processing error: {e}")
            return np.array([], dtype=np.float32)
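
    # np.interp gives cheap linear resampling; a dedicated resampler such as
    # librosa.resample would be higher quality, at the cost of an extra
    # dependency at runtime.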
    def transcribe(self, audio):
        """Transcribe audio using the OpenAI whisper-small model"""
        if audio is None:
            return ""
        sr, data = audio
        # Skip if audio is too short (less than 0.5 seconds)
        if len(data) / sr < 0.5:
            return ""
        try:
            processed = self.process_audio(data, sr)
            # Skip if audio is still too short after processing
            if len(processed) < 8000:  # 0.5 seconds at 16 kHz
                return ""

            # Lazily initialize Whisper via the transformers pipeline
            if self.whisper_pipe is None:
                self.whisper_pipe = pipeline(
                    "automatic-speech-recognition",
                    model="openai/whisper-small",
                    device="cpu"  # Use CPU for consistency
                )
            # Pass the sampling rate explicitly; process_audio resamples to 16 kHz
            result = self.whisper_pipe(
                {"raw": processed, "sampling_rate": 16000},
                return_timestamps=False
            )
            return result["text"]
        except Exception as e:
            print(f"Transcription error: {e}")
            return "❌ Transcription failed - please try again"

    def generate_lesson_plan(self, topic, duration, additional_instructions="", max_tokens=1200):
        """Generate a lesson plan based on document content"""
        if not topic:
            yield "⚠️ Please enter a lesson topic"
            return
        if not self.current_document_text:
            yield "⚠️ Please upload and process a document first"
            return

        prompt = f"""As an expert educator, create a focused lesson plan using the provided content.

**Core Requirements:**
1. TOPIC: {topic}
2. TOTAL DURATION: {duration} periods
3. ADDITIONAL INSTRUCTIONS: {additional_instructions or 'None'}

**Content Summary:**
{self.current_document_text[:2500]}... [truncated]

**Output Structure:**
1. PERIOD ALLOCATION (Break topic into {duration} logical segments):
   - Period 1: [Subtopic 1]
   - Period 2: [Subtopic 2]
   ...
2. LEARNING OBJECTIVES (Max 3 bullet points)
3. TEACHING ACTIVITIES (One engaging method per period)
4. RESOURCES (Key materials from document)
5. ASSESSMENT (Simple checks for understanding)
6. PAGE REFERENCES (Specific source pages)

**Key Rules:**
- Strictly divide content into exactly {duration} periods
- Prioritize document content over creativity
- Keep objectives measurable
- Use only document resources
- Make page references specific"""
        # Use unified streaming generator
        yield from self.generate_text_stream(prompt, max_tokens)
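
    # Google CSE returns at most one result per request here (num=1); unique
    # images are gathered by stepping the `start` index through the result
    # pages (1, 3, 5, ...) until enough distinct links are collected.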
    def fetch_images(self, query: str, num: int = DEFAULT_NUM_IMAGES) -> list:
        """Fetch unique images by requesting different result pages"""
        if num <= 0:
            return []
        try:
            service = build("customsearch", "v1", developerKey=GOOGLE_API_KEY)
            image_links = []
            seen_urls = set()  # Track unique URLs
            # Start from different positions to get unique images
            for start_index in range(1, num * 2, 2):
                if len(image_links) >= num:
                    break
                res = service.cse().list(
                    q=query,
                    cx=GOOGLE_CSE_ID,
                    searchType="image",
                    num=1,
                    start=start_index
                ).execute()
                if "items" in res and res["items"]:
                    item = res["items"][0]
                    # Skip duplicates
                    if item["link"] not in seen_urls:
                        image_links.append(item["link"])
                        seen_urls.add(item["link"])
            return image_links[:num]
        except Exception as e:
            print(f"Error in image fetching: {e}")
            return []


# Initialize the global system object once at import time
ai_system = UnifiedAISystem()

# CSS styles with improved output box
css = """
.gradio-container {
    background-color: #121212;
    color: #fff;
}
.user-msg, .bot-msg {
    padding: 12px 16px;
    border-radius: 18px;
    margin: 8px 0;
    line-height: 1.5;
    border: none;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.user-msg {
    background: linear-gradient(135deg, #4a5568, #2d3748);
    color: white;
    margin-left: 20%;
    border-bottom-right-radius: 5px;
}
.bot-msg {
    background: linear-gradient(135deg, #2d3748, #1a202c);
    color: white;
    margin-right: 20%;
    border-bottom-left-radius: 5px;
}
/* Remove top border from chat messages */
.user-msg, .bot-msg {
    border-top: none !important;
}
/* Remove borders from chat container */
.chatbot > div {
    border: none !important;
}
.chatbot .message {
    border: none !important;
}
/* Improve scrollbar */
.chatbot::-webkit-scrollbar {
    width: 8px;
}
.chatbot::-webkit-scrollbar-track {
    background: #2a2a2a;
    border-radius: 4px;
}
.chatbot::-webkit-scrollbar-thumb {
    background: #4a5568;
    border-radius: 4px;
}
.chatbot::-webkit-scrollbar-thumb:hover {
    background: #5a6578;
}
.upload-box {
    background-color: #333;
    border-radius: 8px;
    padding: 16px;
    margin-bottom: 16px;
}
#question-input {
    background-color: #333;
    color: #fff;
    border-radius: 8px;
    padding: 12px;
    border: 1px solid #555;
}
.mode-checkbox {
    background-color: #333;
    color: #fff;
    border: 1px solid #555;
    border-radius: 8px;
    padding: 10px;
    margin: 5px;
}
.slider-container {
    margin-top: 20px;
    padding: 15px;
    border-radius: 10px;
    background-color: #2a2a2a;
}
.system-info {
    background-color: #7B9BDB;
    padding: 15px;
    border-radius: 8px;
    margin: 15px 0;
    border-left: 4px solid #1890ff;
}
.chat-image {
    cursor: pointer;
    transition: transform 0.2s;
    max-height: 100px;
    margin: 4px;
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.chat-image:hover {
    transform: scale(1.05);
    box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
.modal {
    position: fixed;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    background: rgba(0,0,0,0.8);
    display: none;
    z-index: 1000;
    cursor: zoom-out;
}
.modal-content {
    position: absolute;
    top: 50%;
    left: 50%;
    transform: translate(-50%, -50%);
    max-width: 90%;
    max-height: 90%;
    background: white;
    padding: 10px;
    border-radius: 12px;
}
.modal-img {
    width: auto;
    height: auto;
    max-width: 100%;
    max-height: 100%;
    border-radius: 8px;
}
.typing-indicator {
    display: inline-block;
    position: relative;
    width: 40px;
    height: 20px;
}
.typing-dot {
    display: inline-block;
    width: 6px;
    height: 6px;
    border-radius: 50%;
    background-color: #fff;
    position: absolute;
    animation: typing 1.4s infinite ease-in-out;
}
.typing-dot:nth-child(1) {
    left: 0;
    animation-delay: 0s;
}
.typing-dot:nth-child(2) {
    left: 12px;
    animation-delay: 0.2s;
}
.typing-dot:nth-child(3) {
    left: 24px;
    animation-delay: 0.4s;
}
@keyframes typing {
    0%, 60%, 100% { transform: translateY(0); }
    30% { transform: translateY(-5px); }
}
.lesson-plan {
    background: linear-gradient(135deg, #1a202c, #2d3748);
    padding: 15px;
    border-radius: 12px;
    margin: 10px 0;
    border-left: 4px solid #4a9df0;
}
.lesson-section {
    margin-bottom: 15px;
    padding-bottom: 10px;
    border-bottom: 1px solid #4a5568;
}
.lesson-title {
    font-size: 1.2em;
    font-weight: bold;
    color: #4a9df0;
    margin-bottom: 8px;
}
.page-ref {
    background-color: #4a5568;
    padding: 3px 8px;
    border-radius: 4px;
    font-size: 0.9em;
    display: inline-block;
    margin: 3px;
}
"""

# Create Gradio interface
with gr.Blocks(css=css, title="Unified EDU Assistant") as demo:
    gr.Markdown("# 🤖 Unified EDU Assistant by Phanindra Reddy K")
    # System info banner
    gr.HTML("""
    <div class="system-info">
        <strong>Multi-Modal AI Assistant</strong>
        <ul>
            <li>Text & Voice Chat with Mistral-7B</li>
            <li>Image Understanding with InternVL</li>
            <li>Student Data Analysis</li>
            <li>Visual Search with Google Images</li>
            <li>Lesson Planning from Documents</li>
        </ul>
    </div>
    """)

    # Modal for image preview
    modal_html = """
    <div class="modal" id="imageModal" onclick="this.style.display='none'">
        <div class="modal-content">
            <img class="modal-img" id="expandedImg">
        </div>
    </div>
    <script>
    function showImage(url) {
        document.getElementById('expandedImg').src = url;
        document.getElementById('imageModal').style.display = 'block';
    }
    </script>
    """
    gr.HTML(modal_html)
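
    # NOTE: some Gradio versions sanitize <script> tags inside gr.HTML; if the
    # image modal does not open, the showImage helper may need to be supplied
    # through gr.Blocks(head=...) instead.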

    chat_state = gr.State([])
    with gr.Column(scale=2, elem_classes="chat-container"):
        chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False,
                             avatar_images=("user.png", "bot.png"), show_label=False)

    # Mode selection
    with gr.Row():
        chat_mode = gr.Checkbox(label="💬 General Chat", value=True, elem_classes="mode-checkbox")
        student_mode = gr.Checkbox(label="📊 Student Analytics", value=False, elem_classes="mode-checkbox")
        image_mode = gr.Checkbox(label="🖼️ Image Analysis", value=False, elem_classes="mode-checkbox")
        lesson_mode = gr.Checkbox(label="📚 Lesson Planning", value=False, elem_classes="mode-checkbox")

    # Dynamic input fields (General Chat by default)
    with gr.Column() as chat_inputs:
        include_images = gr.Checkbox(label="Include Visuals", value=True)
        user_input = gr.Textbox(
            placeholder="Type your question here...",
            label="Your Question",
            container=False,
            elem_id="question-input"
        )
        with gr.Row():
            max_tokens = gr.Slider(
                minimum=10,
                maximum=MAX_TOKENS_LIMIT,
                value=DEFAULT_MAX_TOKENS,
                step=10,
                label="Response Length (Tokens)"
            )
            num_images = gr.Slider(
                minimum=0,
                maximum=5,
                value=DEFAULT_NUM_IMAGES,
                step=1,
                label="Number of Images",
                visible=True
            )

    # Student inputs
    with gr.Column(visible=False) as student_inputs:
        file_upload = gr.File(label="CSV/Excel File", file_types=[".csv", ".xlsx"], type="filepath")
        student_question = gr.Textbox(
            placeholder="Ask questions about student data...",
            label="Your Question",
            elem_id="question-input"
        )
        student_status = gr.Markdown("No file loaded")

    # Image analysis inputs
    with gr.Column(visible=False) as image_inputs:
        image_upload = gr.Image(type="pil", label="Upload Image")
        image_url = gr.Textbox(
            label="OR Enter Image URL",
            placeholder="https://example.com/image.jpg",
            elem_id="question-input"
        )
        image_question = gr.Textbox(
            placeholder="Ask questions about the image...",
            label="Your Question",
            elem_id="question-input"
        )

    # Lesson planning inputs
    with gr.Column(visible=False) as lesson_inputs:
        gr.Markdown("### 📚 Lesson Planning")
        doc_upload = gr.File(
            label="Upload Curriculum Document (PDF/DOCX)",
            file_types=[".pdf", ".docx"],
            type="filepath"
        )
        doc_status = gr.Markdown("No document uploaded")
        with gr.Row():
            topic_input = gr.Textbox(
                label="Lesson Topic",
                placeholder="Enter the main topic for the lesson plan"
            )
            duration_input = gr.Number(
                label="Total Periods",
                value=5,
                minimum=1,
                maximum=20,
                step=1
            )
        additional_instructions = gr.Textbox(
            label="Additional Requirements (optional)",
            placeholder="Specific teaching methods, resources, or special considerations..."
        )
        generate_btn = gr.Button("Generate Lesson Plan", variant="primary")

    # Common controls
    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        mic_btn = gr.Button("Transcribe Voice", variant="secondary")
    mic = gr.Audio(sources=["microphone"], type="numpy", label="Voice Input")
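
    # "Transcribe Voice" routes microphone audio through transcribe_audio
    # (wired up below) and drops the text into the main question box; the
    # user then presses Send as usual.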

    # Event handlers
    def toggle_modes(chat, student, image, lesson):
        return [
            gr.update(visible=chat),
            gr.update(visible=student),
            gr.update(visible=image),
            gr.update(visible=lesson)
        ]

    def load_student_file(file_path):
        success, message = ai_system.load_data(file_path)
        return message

    def process_document(file_path):
        if not file_path:
            return "⚠️ Please select a document first"
        success, message = ai_system.extract_text_from_document(file_path)
        return message

    def render_history(history):
        """Render chat history with images and proper formatting"""
        rendered = []
        for user_msg, bot_msg, image_links in history:
            user_html = f"<div class='user-msg'>{user_msg}</div>"
            # Ensure bot_msg is a string before checking substrings
            bot_text = str(bot_msg)
            if "Lesson Plan:" in bot_text:
                bot_html = f"<div class='lesson-plan'>{bot_text}</div>"
            else:
                bot_html = f"<div class='bot-msg'>{bot_text}</div>"
            # Add images if available
            if image_links:
                images_html = "".join(
                    f"<img src='{url}' class='chat-image' onclick='showImage(\"{url}\")' />"
                    for url in image_links
                )
                bot_html += f"<br><br><b>📸 Related Visuals:</b><br><div style='display: flex; flex-wrap: wrap;'>{images_html}</div>"
            rendered.append((user_html, bot_html))
        return rendered
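
    # respond() is a generator: each yield updates (chatbot, user_input,
    # chat_state), so the UI shows a typing indicator first, then partial text
    # as chunks stream in, and finally the completed reply with any images.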
    def respond(message, history, chat, student, image, lesson,
                tokens, student_q, image_q, image_upload, image_url,
                include_visuals, num_imgs, topic, duration, additional):
        """
        1. Use actual_message (depending on mode) instead of the raw `message`.
        2. Convert any non-string bot response (like VLMDecodedResults) to str().
        3. Disable the input box during streaming, then re-enable it at the end.
        """
        updated_history = list(history)

        # Determine which prompt to actually send
        if student:
            actual_message = student_q
        elif image:
            actual_message = image_q
        elif lesson:
            actual_message = f"Generate lesson plan for: {topic} ({duration} periods)"
            if additional:
                actual_message += f"\nAdditional: {additional}"
        else:
            actual_message = message

        # Add a "typing" placeholder entry using actual_message
        typing_html = "<div class='typing-indicator'><div class='typing-dot'></div><div class='typing-dot'></div><div class='typing-dot'></div></div>"
        updated_history.append((actual_message, typing_html, []))
        # First yield: clear & disable the input box while streaming
        yield render_history(updated_history), gr.update(value="", interactive=False), updated_history

        full_response = ""
        images = []
        try:
            if chat:
                # General chat mode (streaming)
                for chunk in ai_system.generate_text_stream(actual_message, tokens):
                    full_response = chunk
                    updated_history[-1] = (actual_message, full_response, [])
                    yield render_history(updated_history), gr.update(value="", interactive=False), updated_history
                if include_visuals:
                    images = ai_system.fetch_images(actual_message, num_imgs)
            elif student:
                # Student analytics mode (streaming)
                if ai_system.current_df is None:
                    full_response = "⚠️ Please upload a student data file first"
                else:
                    for chunk in ai_system.analyze_student_data(student_q, tokens):
                        full_response = chunk
                        updated_history[-1] = (actual_message, full_response, [])
                        yield render_history(updated_history), gr.update(value="", interactive=False), updated_history
            elif image:
                # Image analysis mode (synchronous)
                if (not image_upload) and (not image_url):
                    full_response = "⚠️ Please upload an image or enter a URL"
                else:
                    # analyze_image() may return a VLM result object, not a plain string
                    result_obj = ai_system.analyze_image(image_upload, image_url, image_q)
                    full_response = str(result_obj)
            elif lesson:
                # Lesson planning mode (streaming)
                if not topic:
                    full_response = "⚠️ Please enter a lesson topic"
                else:
                    duration = int(duration) if duration else 5
                    for chunk in ai_system.generate_lesson_plan(topic, duration, additional, tokens):
                        full_response = chunk
                        updated_history[-1] = (actual_message, full_response, [])
                        yield render_history(updated_history), gr.update(value="", interactive=False), updated_history

            # Final update: attach images (if any) and trim history
            updated_history[-1] = (actual_message, full_response, images)
            if len(updated_history) > MAX_HISTORY_TURNS:
                updated_history = updated_history[-MAX_HISTORY_TURNS:]
        except Exception as e:
            error_msg = f"❌ Error: {str(e)}"
            updated_history[-1] = (actual_message, error_msg, [])
        # Final yield: clear & re-enable the input box
        yield render_history(updated_history), gr.update(value="", interactive=True), updated_history

    # Voice transcription
    def transcribe_audio(audio):
        return ai_system.transcribe(audio)

    # Mode toggles
    chat_mode.change(fn=toggle_modes, inputs=[chat_mode, student_mode, image_mode, lesson_mode],
                     outputs=[chat_inputs, student_inputs, image_inputs, lesson_inputs])
    student_mode.change(fn=toggle_modes, inputs=[chat_mode, student_mode, image_mode, lesson_mode],
                        outputs=[chat_inputs, student_inputs, image_inputs, lesson_inputs])
    image_mode.change(fn=toggle_modes, inputs=[chat_mode, student_mode, image_mode, lesson_mode],
                      outputs=[chat_inputs, student_inputs, image_inputs, lesson_inputs])
    lesson_mode.change(fn=toggle_modes, inputs=[chat_mode, student_mode, image_mode, lesson_mode],
                       outputs=[chat_inputs, student_inputs, image_inputs, lesson_inputs])

    # File upload handler
    file_upload.change(fn=load_student_file, inputs=file_upload, outputs=student_status)
    # Document upload handler
    doc_upload.change(fn=process_document, inputs=doc_upload, outputs=doc_status)

    mic_btn.click(fn=transcribe_audio, inputs=mic, outputs=user_input)

    # Submit handler
    submit_btn.click(
        fn=respond,
        inputs=[
            user_input, chat_state, chat_mode, student_mode, image_mode, lesson_mode,
            max_tokens, student_question, image_question, image_upload, image_url,
            include_images, num_images,
            topic_input, duration_input, additional_instructions
        ],
        outputs=[chatbot, user_input, chat_state]
    )
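
    # The lesson-plan button reuses respond(); freshly created invisible
    # components act as placeholders for the chat-mode inputs that lesson
    # planning does not use.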
    # Lesson plan generation button
    generate_btn.click(
        fn=respond,
        inputs=[
            gr.Textbox(value="Generate lesson plan", visible=False),  # Hidden message
            chat_state,
            chat_mode, student_mode, image_mode, lesson_mode,
            max_tokens,
            gr.Textbox(visible=False),   # student_q
            gr.Textbox(visible=False),   # image_q
            gr.Image(visible=False),     # image_upload
            gr.Textbox(visible=False),   # image_url
            gr.Checkbox(visible=False),  # include_visuals
            gr.Slider(visible=False),    # num_imgs
            topic_input,                 # Pass topic
            duration_input,              # Pass duration
            additional_instructions     # Pass additional instructions
        ],
        outputs=[chatbot, user_input, chat_state]
    )

if __name__ == "__main__":
    demo.launch(share=True, debug=True, show_api=False)