Update app.py

app.py CHANGED
@@ -31,9 +31,7 @@ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
 from diffusers.utils import export_to_ply
 
-# -----------------------------------------------------------------------------
 # Global constants and helper functions
-# -----------------------------------------------------------------------------
 
 MAX_SEED = np.iinfo(np.int32).max
 
@@ -52,9 +50,7 @@ def glb_to_data_url(glb_path: str) -> str:
     b64_data = base64.b64encode(data).decode("utf-8")
     return f"data:model/gltf-binary;base64,{b64_data}"
 
-# -----------------------------------------------------------------------------
 # Model class for Text-to-3D Generation (ShapE)
-# -----------------------------------------------------------------------------
 
 class Model:
     def __init__(self):
@@ -113,9 +109,7 @@ class Model:
         export_to_ply(images[0], ply_path.name)
         return self.to_glb(ply_path.name)
 
-# -----------------------------------------------------------------------------
 # New Tools for Web Functionality using DuckDuckGo and smolagents
-# -----------------------------------------------------------------------------
 
 from typing import Any, Optional
 from smolagents.tools import Tool
@@ -186,14 +180,68 @@ class VisitWebpageTool(Tool):
             return f"Error fetching the webpage: {str(e)}"
         except Exception as e:
             return f"An unexpected error occurred: {str(e)}"
+
+# New Feature: rAgent Reasoning using Llama mode OpenAI
+
+from openai import OpenAI
+
+ACCESS_TOKEN = os.getenv("HF_TOKEN")
+ragent_client = OpenAI(
+    base_url="https://api-inference.huggingface.co/v1/",
+    api_key=ACCESS_TOKEN,
+)
+
+SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
+
+To do so, you must follow a structured reasoning process in a cycle of:
+
+1. **Thought:**
+   - Analyze the problem and explain your reasoning.
+   - Identify any necessary tools or techniques.
+
+2. **Code:**
+   - Implement the solution using Python.
+   - Enclose the code block with `<end_code>`.
+
+3. **Observation:**
+   - Explain the output and verify correctness.
+
+4. **Final Answer:**
+   - Summarize the solution clearly.
+
+Always adhere to the **Thought → Code → Observation → Final Answer** structure.
+"""
+
+def ragent_reasoning(prompt: str, history: list[dict], max_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95):
+    """
+    Uses the Llama mode OpenAI model to perform a structured reasoning chain.
+    """
+    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
+    # Incorporate conversation history (if any)
+    for msg in history:
+        if msg.get("role") == "user":
+            messages.append({"role": "user", "content": msg["content"]})
+        elif msg.get("role") == "assistant":
+            messages.append({"role": "assistant", "content": msg["content"]})
+    messages.append({"role": "user", "content": prompt})
+    response = ""
+    stream = ragent_client.chat.completions.create(
+        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        max_tokens=max_tokens,
+        stream=True,
+        temperature=temperature,
+        top_p=top_p,
+        messages=messages,
+    )
+    for message in stream:
+        token = message.choices[0].delta.content
+        response += token
+        yield response
 
-# -----------------------------------------------------------------------------
 # Gradio UI configuration
-# -----------------------------------------------------------------------------
 
 DESCRIPTION = """
-# Agent Dino 🌠
-"""
+# Agent Dino 🌠 """
 
 css = '''
 h1 {
@@ -215,11 +263,9 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# -----------------------------------------------------------------------------
 # Load Models and Pipelines for Chat, Image, and Multimodal Processing
-# -----------------------------------------------------------------------------
-
 # Load the text-only model and tokenizer (for pure text chat)
+
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
@@ -244,9 +290,7 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
-# -----------------------------------------------------------------------------
 # Asynchronous text-to-speech
-# -----------------------------------------------------------------------------
 
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     """Convert text to speech using Edge TTS and save as MP3"""
@@ -254,9 +298,7 @@ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     await communicate.save(output_file)
     return output_file
 
-# -----------------------------------------------------------------------------
 # Utility function to clean conversation history
-# -----------------------------------------------------------------------------
 
 def clean_chat_history(chat_history):
     """
@@ -269,9 +311,7 @@ def clean_chat_history(chat_history):
             cleaned.append(msg)
     return cleaned
 
-# -----------------------------------------------------------------------------
 # Stable Diffusion XL Pipeline for Image Generation
-# -----------------------------------------------------------------------------
 
 MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
@@ -350,9 +390,7 @@ def generate_image_fn(
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
-# -----------------------------------------------------------------------------
 # Text-to-3D Generation using the ShapE Pipeline
-# -----------------------------------------------------------------------------
 
 @spaces.GPU(duration=120, enable_queue=True)
 def generate_3d_fn(
@@ -371,9 +409,7 @@ def generate_3d_fn(
     glb_path = model3d.run_text(prompt, seed=seed, guidance_scale=guidance_scale, num_steps=num_steps)
     return glb_path, seed
 
-# -----------------------------------------------------------------------------
-# Chat Generation Function with support for @tts, @image, @3d, and now @web commands
-# -----------------------------------------------------------------------------
+# Chat Generation Function with support for @tts, @image, @3d, @web, and @rAgent commands
 
 @spaces.GPU
 def generate(
@@ -386,14 +422,12 @@ def generate(
     repetition_penalty: float = 1.2,
 ):
     """
-    Generates chatbot responses with support for multimodal input
-    3D model generation, and web search/visit.
-
-    Special commands:
+    Generates chatbot responses with support for multimodal input and special commands:
     - "@tts1" or "@tts2": triggers text-to-speech.
    - "@image": triggers image generation using the SDXL pipeline.
     - "@3d": triggers 3D model generation using the ShapE pipeline.
-    - "@web": triggers a web search or webpage visit.
+    - "@web": triggers a web search or webpage visit.
+    - "@rAgent": initiates a reasoning chain using Llama mode OpenAI.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
@@ -401,7 +435,7 @@ def generate(
     # --- 3D Generation branch ---
     if text.strip().lower().startswith("@3d"):
         prompt = text[len("@3d"):].strip()
-        yield "Hold tight, generating a 3D mesh GLB file....."
+        yield "🌀 Hold tight, generating a 3D mesh GLB file....."
         glb_path, used_seed = generate_3d_fn(
             prompt=prompt,
             seed=1,
@@ -423,7 +457,7 @@
     # --- Image Generation branch ---
     if text.strip().lower().startswith("@image"):
         prompt = text[len("@image"):].strip()
-        yield "Generating image..."
+        yield "🪧 Generating image..."
        image_paths, used_seed = generate_image_fn(
            prompt=prompt,
            negative_prompt="",
@@ -446,19 +480,28 @@
         # If the command starts with "visit", then treat the rest as a URL
         if web_command.lower().startswith("visit"):
             url = web_command[len("visit"):].strip()
-            yield "Visiting webpage..."
+            yield "🌍 Visiting webpage..."
             visitor = VisitWebpageTool()
             content = visitor.forward(url)
             yield content
         else:
             # Otherwise, treat the rest as a search query.
             query = web_command
-            yield "Performing a web search ..."
+            yield "🧤 Performing a web search ..."
             searcher = DuckDuckGoSearchTool()
             results = searcher.forward(query)
             yield results
         return
 
+    # --- rAgent Reasoning branch ---
+    if text.strip().lower().startswith("@ragent"):
+        prompt = text[len("@ragent"):].strip()
+        yield "Initiating reasoning chain using Llama mode..."
+        # Pass the current chat history (cleaned) to help inform the chain.
+        for partial in ragent_reasoning(prompt, clean_chat_history(chat_history)):
+            yield partial
+        return
+
     # --- Text and TTS branch ---
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
@@ -496,7 +539,7 @@
     thread.start()
 
     buffer = ""
-    yield "Thinking..."
+    yield "🤔 Thinking..."
     for new_text in streamer:
         buffer += new_text
         buffer = buffer.replace("<|im_end|>", "")
@@ -535,9 +578,7 @@
         output_file = asyncio.run(text_to_speech(final_response, voice))
         yield gr.Audio(output_file, autoplay=True)
 
-# -----------------------------------------------------------------------------
 # Gradio Chat Interface Setup and Launch
-# -----------------------------------------------------------------------------
 
 demo = gr.ChatInterface(
     fn=generate,
@@ -553,8 +594,9 @@ demo = gr.ChatInterface(
         ["@3d A birthday cupcake with cherry"],
         [{"text": "summarize the letter", "files": ["examples/1.png"]}],
         ["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
-        ["
+        ["@rAgent Explain how a binary search algorithm works."],
         ["@web latest breakthroughs in renewable energy"],
+
     ],
     cache_examples=False,
     type="messages",
@@ -570,10 +612,8 @@
 if not os.path.exists("static"):
     os.makedirs("static")
 
-# Mount the static folder onto the FastAPI app so that GLB files are served.
 from fastapi.staticfiles import StaticFiles
 demo.app.mount("/static", StaticFiles(directory="static"), name="static")
 
 if __name__ == "__main__":
-    # Launch without the unsupported static_dirs parameter.
     demo.queue(max_size=20).launch(share=True)