import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, pipeline
from PIL import Image
import torch
import warnings

# Suppress warnings
warnings.filterwarnings("ignore")

# Load Phi-3.5-vision model
phi_model_id = "microsoft/Phi-3.5-vision-instruct"
try:
    # Prefer FlashAttention 2 when it is installed
    phi_model = AutoModelForCausalLM.from_pretrained(
        phi_model_id,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,  # Use float16 to reduce memory usage
        _attn_implementation="flash_attention_2"
    )
except ImportError:
    print("FlashAttention not available, falling back to eager implementation.")
    phi_model = AutoModelForCausalLM.from_pretrained(
        phi_model_id,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        _attn_implementation="eager"
    )

phi_processor = AutoProcessor.from_pretrained(phi_model_id, trust_remote_code=True)

# Load Llama 3.1 model
llama_model_id = "meta-llama/Llama-3.1-8B"
try:
    llama_pipeline = pipeline("text-generation", model=llama_model_id, device_map="auto", torch_dtype=torch.float16)
except Exception as e:
    print(f"Error loading Llama 3.1 model: {e}")
    print("Falling back to a smaller, open-source model.")
    llama_model_id = "gpt2"  # Fallback to a smaller, open-source model
    llama_pipeline = pipeline("text-generation", model=llama_model_id, device_map="auto")


def analyze_image(image, query):
    # Build the Phi-3.5-vision chat prompt with a single image placeholder
    prompt = f"<|user|>\n<|image_1|>\n{query}<|end|>\n<|assistant|>\n"
    inputs = phi_processor(prompt, images=image, return_tensors="pt").to(phi_model.device)
    with torch.no_grad():
        output = phi_model.generate(**inputs, max_new_tokens=100)
    # Decode only the newly generated tokens, not the prompt
    generated = output[:, inputs["input_ids"].shape[1]:]
    return phi_processor.decode(generated[0], skip_special_tokens=True)


def generate_text(query, history):
    # Flatten prior (user, assistant) turns into a plain-text context
    context = "\n".join([f"Human: {h[0]}\nAI: {h[1]}" for h in history])
    prompt = f"{context}\nHuman: {query}\nAI:"
    response = llama_pipeline(prompt, max_new_tokens=100, do_sample=True, temperature=0.7)[0]["generated_text"]
    return response.split("AI:")[-1].strip()


def chatbot(image, query, history):
    # Route to the vision model when an image is provided, otherwise use the text model
    if image is not None:
        response = analyze_image(Image.fromarray(image), query)
    else:
        response = generate_text(query, history)
    history.append((query, response))
    return "", history, history


with gr.Blocks() as demo:
    gr.Markdown("# Multi-Modal AI Assistant")
    with gr.Row():
        image_input = gr.Image(type="numpy", label="Upload an image (optional)")
        chat_history = gr.Chatbot(label="Chat History")
    query_input = gr.Textbox(label="Ask a question or enter a prompt")
    submit_button = gr.Button("Submit")
    state = gr.State([])

    submit_button.click(
        chatbot,
        inputs=[image_input, query_input, state],
        outputs=[query_input, chat_history, state]
    )

demo.launch()