Spaces: Running on A10G

Update app.py #7
by hjbfd - opened

app.py CHANGED
@@ -1,839 +1,182 @@
-import os
-import sys
-import json
-import argparse
-import time
-import uuid
-import subprocess
-import requests
-from typing import List, Dict, Any, Iterator
-
-from dotenv import load_dotenv
-load_dotenv()
-
 import gradio as gr
-
-
-
-
-
-from agentflow.models.memory import Memory
-from agentflow.models.executor import Executor
-from agentflow.models.utils import make_json_serializable_truncated
-
-
-from pathlib import Path
-from huggingface_hub import CommitScheduler
-
-# Get Huggingface token from environment variable
-HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
-
-########### Test Huggingface Dataset ###########
-# Update the HuggingFace dataset constants
-DATASET_DIR = Path("solver_cache") # the directory to save the dataset
-DATASET_DIR.mkdir(parents=True, exist_ok=True)
-
-global QUERY_ID
-QUERY_ID = None
-
-TOOL_NAME_MAPPING = {
-    "Generalist_Solution_Generator_Tool": "Base_Generator_Tool",
-    "Ground_Google_Search_Tool": "Google_Search_Tool",
-    "Python_Code_Generator_Tool": "Python_Coder_Tool",
-    "Web_RAG_Search_Tool": "Web_Search_Tool",
-    "Wikipedia_RAG_Search_Tool": "Wikipedia_Search_Tool"
-}
-
-# Enable scheduler to record data to HuggingFace dataset
-# scheduler = None
-scheduler = CommitScheduler(
-    repo_id="ZhuofengLi/AgentFlow-Gradio-Demo-User-Data",
-    repo_type="dataset",
-    folder_path=DATASET_DIR,
-    path_in_repo="solver_cache", # Update path in repo
-    token=HF_TOKEN
-)
-
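For context on the scheduler block above: huggingface_hub's CommitScheduler pushes the contents of folder_path to the Hub repo from a background thread on a fixed schedule (every few minutes by default), so anything the app writes under solver_cache/ eventually lands in the ZhuofengLi/AgentFlow-Gradio-Demo-User-Data dataset. A minimal sketch of a scheduler-safe write, assuming the scheduler.lock guard described in the huggingface_hub docs (the example.json payload is illustrative):

    import json

    # Hold scheduler.lock while writing so a background commit never uploads
    # a half-written file.
    with scheduler.lock:
        with (DATASET_DIR / "example.json").open("w") as f:
            json.dump({"hello": "world"}, f)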
-########### vLLM Service Management ###########
-VLLM_MODEL_NAME = "AgentFlow/agentflow-planner-7b"
-VLLM_PORT = 8000
-VLLM_HOST = "localhost"
-VLLM_PROCESS = None
-
-def check_vllm_service() -> bool:
-    """Check if vLLM service is running"""
-    try:
-        response = requests.get(f"http://{VLLM_HOST}:{VLLM_PORT}/v1/models", timeout=2)
-        return response.status_code == 200
-    except:
-        return False
-
-def start_vllm_service() -> bool:
-    """Start vLLM service in background"""
-    global VLLM_PROCESS
-
-    if check_vllm_service():
-        print(f"🟢 vLLM service already running on port {VLLM_PORT}")
-        return True
-
-    try:
-        print(f"🚀 Starting vLLM service for {VLLM_MODEL_NAME}...")
-
-        # Start vLLM server in background
-        VLLM_PROCESS = subprocess.Popen(
-            [
-                "vllm", "serve", VLLM_MODEL_NAME,
-                "--port", str(VLLM_PORT),
-                "--host", VLLM_HOST
-            ],
-            text=True
-        )
-
-        # Wait for service to be ready (max 180 seconds)
-        for i in range(180):
-            time.sleep(1)
-            if check_vllm_service():
-                print(f"🟢 vLLM service started successfully on port {VLLM_PORT}")
-                return True
-
-        print("⚠️ vLLM service failed to start within 180 seconds")
-        return False
-
-    except Exception as e:
-        print(f"❌ Failed to start vLLM service: {e}")
-        return False
-
-def stop_vllm_service():
-    """Stop vLLM service if running"""
-    global VLLM_PROCESS
-    if VLLM_PROCESS:
-        VLLM_PROCESS.terminate()
-        VLLM_PROCESS.wait()
-        print("🛑 vLLM service stopped")
-
-def get_vllm_status() -> str:
-    """Get vLLM service status message"""
-    if check_vllm_service():
-        return f"🟢 vLLM service running on port {VLLM_PORT}"
-    else:
-        return "⚠️ vLLM service not running"
-
-########### End of vLLM Service Management ###########
-
-
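Since check_vllm_service() above already polls the server's OpenAI-compatible /v1/models route, the same server can be exercised end to end over HTTP. A minimal sketch against vLLM's standard /v1 endpoints (the prompt and max_tokens value are illustrative):

    import requests

    base = "http://localhost:8000/v1"
    print(requests.get(f"{base}/models", timeout=2).json())  # same health check as above

    # One chat completion against the planner model; the payload fields follow
    # the standard OpenAI-compatible schema that vLLM serves.
    resp = requests.post(
        f"{base}/chat/completions",
        json={
            "model": "AgentFlow/agentflow-planner-7b",
            "messages": [{"role": "user", "content": "Say hello."}],
            "max_tokens": 32,
        },
        timeout=60,
    )
    print(resp.json()["choices"][0]["message"]["content"])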
-def save_query_data(query_id: str, query: str) -> None:
-    """Save query data to dataset"""
-    # Save query metadata
-    query_cache_dir = DATASET_DIR / query_id
-    query_cache_dir.mkdir(parents=True, exist_ok=True)
-    query_file = query_cache_dir / "query_metadata.json"
-
-    query_metadata = {
-        "query_id": query_id,
-        "query_text": query,
-        "datetime": time.strftime("%Y%m%d_%H%M%S"),
-    }
-
-    print(f"Saving query metadata to {query_file}")
-    with query_file.open("w") as f:
-        json.dump(query_metadata, f, indent=4)
-
 
-def save_feedback(query_id: str, feedback_type: str, feedback_text: str = None) -> None:
    """
-
-
-    Args:
-        query_id: Unique identifier for the query
-        feedback_type: Type of feedback ('upvote', 'downvote', or 'comment')
-        feedback_text: Optional text feedback from user
    """
-
-    feedback_data_dir = DATASET_DIR / query_id
-    feedback_data_dir.mkdir(parents=True, exist_ok=True)
 
-
-
-
-
-        "datetime": time.strftime("%Y%m%d_%H%M%S")
-    }
 
-    #
-
-
 
-
-
-    with feedback_file.open("r") as f:
-        existing_feedback = json.load(f)
-        # Convert to list if it's a single feedback entry
-        if not isinstance(existing_feedback, list):
-            existing_feedback = [existing_feedback]
-        existing_feedback.append(feedback_data)
-        feedback_data = existing_feedback
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
 
 
-def save_module_data(query_id: str, key: str, value: Any) -> None:
-    """Save module data to Huggingface dataset"""
     try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-    except Exception as e:
-        print(f"Error: Failed to save as text file: {e}")
-
-########### End of Test Huggingface Dataset ###########
-
-class Solver:
-    def __init__(
-        self,
-        planner,
-        memory,
-        executor,
-        output_types: str = "base,final,direct",
-        index: int = 0,
-        verbose: bool = True,
-        max_steps: int = 10,
-        max_time: int = 60,
-        query_cache_dir: str = "solver_cache"
-    ):
-        self.planner = planner
-        self.memory = memory
-        self.executor = executor
-        self.index = index
-        self.verbose = verbose
-        self.max_steps = max_steps
-        self.max_time = max_time
-        self.query_cache_dir = query_cache_dir
-
-        self.output_types = output_types.lower().split(',')
-        assert all(output_type in ["base", "final", "direct"] for output_type in self.output_types), "Invalid output type. Supported types are 'base', 'final', 'direct'."
-
-
-    def stream_solve_user_problem(self, user_query: str, messages: List[ChatMessage]) -> Iterator[List[ChatMessage]]:
-        """
-        Streams intermediate thoughts and final responses for the problem-solving process based on user input.
-
-        Args:
-            user_query (str): The text query input from the user.
-            messages (list): A list of ChatMessage objects to store the streamed responses.
         """
-
-        img_path = None # AgentFlow doesn't use images in this demo
-
-        # Set tool cache directory
-        _tool_cache_dir = os.path.join(self.query_cache_dir, "tool_cache") # NOTE: This is the directory for tool cache
-        self.executor.set_query_cache_dir(_tool_cache_dir) # NOTE: set query cache directory
 
-
-
-
-
-        # # Step 2: Add "thinking" status while processing
-        # messages.append(ChatMessage(
-        #     role="assistant",
-        #     content="",
-        #     metadata={"title": "⏳ Thinking: Processing input..."}
-        # ))
-
-        # [Step 3] Initialize problem-solving state
-        start_time = time.time()
-        step_count = 0
-        json_data = {"query": user_query, "image": "Image received as bytes"}
-
-        messages.append(ChatMessage(role="assistant", content="<br>"))
-        messages.append(ChatMessage(role="assistant", content="### 🧠 Reasoning Steps from AgentFlow (Deep Reasoning...)"))
-        yield messages
-
-        # [Step 4] Query Analysis
-        query_analysis = self.planner.analyze_query(user_query, img_path)
-        json_data["query_analysis"] = query_analysis # TODO: update
-
-        # Format the query analysis for display
-        query_analysis_display = query_analysis.replace("Concise Summary:", "**Concise Summary:**\n")
-        query_analysis_display = query_analysis_display.replace("Required Skills:", "**Required Skills:**")
-        query_analysis_display = query_analysis_display.replace("Relevant Tools:", "**Relevant Tools:**")
-        query_analysis_display = query_analysis_display.replace("Additional Considerations:", "**Additional Considerations:**")
-
-        # Map tool names in query analysis for display
-        for original_name, display_name in TOOL_NAME_MAPPING.items():
-            query_analysis_display = query_analysis_display.replace(original_name, display_name)
-
-        messages.append(ChatMessage(role="assistant",
-                                    content=f"{query_analysis_display}",
-                                    metadata={"title": "### 🔎 Step 0: Query Analysis"}))
-        yield messages
-
-        # Save the query analysis data
-        query_analysis_data = {
-            "query_analysis": query_analysis,
-            "time": round(time.time() - start_time, 5)
-        }
-        save_module_data(QUERY_ID, "step_0_query_analysis", query_analysis_data)
-
-
-
-        # Execution loop (similar to your step-by-step solver)
-        while step_count < self.max_steps and (time.time() - start_time) < self.max_time:
-            step_count += 1
-            messages.append(ChatMessage(role="AgentFlow",
-                                        content=f"Generating the {step_count}-th step...",
-                                        metadata={"title": f"🔄 Step {step_count}"}))
-            yield messages
 
-
-
-
            )
-
-
-
-
-
-
-
-
-
-
-            # Display the step information
-            display_tool_name = TOOL_NAME_MAPPING.get(tool_name, tool_name)
-
-            # Map tool names in context and sub_goal for display
-            context_display = context if context else ""
-            sub_goal_display = sub_goal if sub_goal else ""
-            for original_name, display_name in TOOL_NAME_MAPPING.items():
-                context_display = context_display.replace(original_name, display_name)
-                sub_goal_display = sub_goal_display.replace(original_name, display_name)
-
-            messages.append(ChatMessage(
-                role="assistant",
-                content=f"**Context:** {context_display}\n\n**Sub-goal:** {sub_goal_display}\n\n**Tool:** `{display_tool_name}`",
-                metadata={"title": f"### 🎯 Step {step_count}: Action Prediction ({display_tool_name})"}))
-            yield messages
-
-            # Handle tool execution or errors
-            if tool_name not in self.planner.available_tools:
-                display_tool_name = TOOL_NAME_MAPPING.get(tool_name, tool_name)
-                messages.append(ChatMessage(
-                    role="assistant",
-                    content=f"⚠️ Error: Tool '{display_tool_name}' is not available."))
-                yield messages
-                continue
-
-            # [Step 6-7] Generate and execute the tool command
-            tool_command = self.executor.generate_tool_command(
-                user_query, img_path, context, sub_goal, tool_name, self.planner.toolbox_metadata[tool_name], step_count, json_data
            )
-            analysis, explanation, command = self.executor.extract_explanation_and_command(tool_command)
-            result = self.executor.execute_tool_command(tool_name, command)
-            result = make_json_serializable_truncated(result)
-
-            # Display the command generation information
-            display_tool_name = TOOL_NAME_MAPPING.get(tool_name, tool_name)
-            messages.append(ChatMessage(
-                role="assistant",
-                content=f"**Command:**\n```python\n{command}\n```",
-                metadata={"title": f"### 📋 Step {step_count}: Command Generation ({display_tool_name})"}))
-            yield messages
-
-            # Save the command generation data
-            command_generation_data = {
-                "analysis": analysis,
-                "explanation": explanation,
-                "command": command,
-                "time": round(time.time() - start_time, 5)
-            }
-            save_module_data(QUERY_ID, f"step_{step_count}_command_generation", command_generation_data)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            # Save the command execution data
-            command_execution_data = {
-                "result": result,
-                "time": round(time.time() - start_time, 5)
-            }
-            save_module_data(QUERY_ID, f"step_{step_count}_command_execution", command_execution_data)
-
-            # [Step 8] Memory update and stopping condition
-            self.memory.add_action(step_count, tool_name, sub_goal, command, result) # TODO: do not update here
-            stop_verification = self.planner.verificate_context(user_query, img_path, query_analysis, self.memory, step_count, json_data)
-            context_verification, conclusion = self.planner.extract_conclusion(stop_verification)
-
-            # Save the context verification data
-            context_verification_data = {
-                "stop_verification": context_verification,
-                "conclusion": conclusion,
-                "time": round(time.time() - start_time, 5)
-            }
-            save_module_data(QUERY_ID, f"step_{step_count}_context_verification", context_verification_data)
-
-            # Display the context verification result # TODO: update context_verification
-            # Map tool names in context verification for display
-            context_verification_display = context_verification if context_verification else ""
-            for original_name, display_name in TOOL_NAME_MAPPING.items():
-                context_verification_display = context_verification_display.replace(original_name, display_name)
-
-            conclusion_emoji = "✅" if conclusion == 'STOP' else "🛑"
-            messages.append(ChatMessage(
-                role="assistant",
-                content=f"**Analysis:**\n{context_verification_display}\n\n**Conclusion:** `{conclusion}` {conclusion_emoji}",
-                metadata={"title": f"### 🤖 Step {step_count}: Context Verification"}))
-            yield messages
-
-            if conclusion == 'STOP':
-                break
-
-        # Step 7: Generate Final Output (if needed)
-        if 'direct' in self.output_types:
-            messages.append(ChatMessage(role="assistant", content="<br>"))
-            direct_output = self.planner.generate_direct_output(user_query, img_path, self.memory) # TODO: update
-
-            # Map tool names in direct output for display
-            direct_output_display = direct_output if direct_output else ""
-            for original_name, display_name in TOOL_NAME_MAPPING.items():
-                direct_output_display = direct_output_display.replace(original_name, display_name)
-
-            messages.append(ChatMessage(role="assistant", content=f"### 🎉 Final Answer:\n{direct_output_display}"))
-            yield messages
-
-            # Save the direct output data
-            direct_output_data = {
-                "direct_output": direct_output,
-                "time": round(time.time() - start_time, 5)
-            }
-            save_module_data(QUERY_ID, "direct_output", direct_output_data)
-
-
-        if 'final' in self.output_types:
-            final_output = self.planner.generate_final_output(user_query, img_path, self.memory) # Disabled visibility for now
-            # messages.append(ChatMessage(role="assistant", content=f"🎯 Final Output:\n{final_output}"))
-            # yield messages
-
-            # Save the final output data
-            final_output_data = {
-                "final_output": final_output,
-                "time": round(time.time() - start_time, 5)
-            }
-            save_module_data(QUERY_ID, "final_output", final_output_data)
-
-        # Step 8: Completion Message
-        messages.append(ChatMessage(role="assistant", content="<br>"))
-        messages.append(ChatMessage(role="assistant", content="### ✨ Query Solved!"))
-        messages.append(ChatMessage(role="assistant", content="How do you like the output from AgentFlow 🌀💫? Please give us your feedback below. \n\n👍 If the answer is correct or the reasoning steps are helpful, please upvote the output. \n👎 If it is incorrect or the reasoning steps are not helpful, please downvote the output. \n💬 If you have any suggestions or comments, please leave them below.\n\nThank you for using AgentFlow! 🌀💫"))
-        yield messages
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    parser.add_argument("--openai_api_source", default="we_provided", choices=["we_provided", "user_provided"], help="Source of OpenAI API key.")
-    return parser.parse_args()
-
-
-def solve_problem_gradio(user_query, max_steps=10, max_time=60, llm_model_engine=None, enabled_tools=None):
-    """
-    Wrapper function to connect the solver to Gradio.
-    Streams responses from `solver.stream_solve_user_problem` for real-time UI updates.
-    """
-
-    # Check if query is empty
-    if not user_query or not user_query.strip():
-        yield [ChatMessage(role="assistant", content="❌ Error: Please enter a question before submitting.")]
-        return
-
-    # Generate Unique Query ID (Date and first 8 characters of UUID)
-    query_id = time.strftime("%Y%m%d_%H%M%S") + "_" + str(uuid.uuid4())[:8] # e.g., 20250217_062225_612f2474
-    print(f"Query ID: {query_id}")
-
-    # NOTE: update the global variable to save the query ID
-    global QUERY_ID
-    QUERY_ID = query_id
-
-    # Create a directory for the query ID
-    query_cache_dir = os.path.join(DATASET_DIR.name, query_id) # NOTE
-    os.makedirs(query_cache_dir, exist_ok=True)
-
-    # if api_key is None:
-    #     return [["assistant", "❌ Error: OpenAI API Key is required."]]
 
-
-
-
-
-    )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    # Instantiate Executor
-    executor = Executor(
-        llm_engine_name="dashscope", # AgentFlow uses dashscope for executor
-        root_cache_dir=query_cache_dir, # NOTE
-        verbose=False,
-        temperature=0.7,
-        enable_signal=False
-    )
-
-    # Instantiate Solver
-    solver = Solver(
-        planner=planner,
-        memory=memory,
-        executor=executor,
-        output_types=args.output_types, # Add new parameter
-        verbose=args.verbose,
-        max_steps=max_steps,
-        max_time=max_time,
-        query_cache_dir=query_cache_dir # NOTE
-    )
-
-    if solver is None:
-        return [["assistant", "❌ Error: Solver is not initialized. Please restart the application."]]
-
-
-    messages = [] # Initialize message list
-    for message_batch in solver.stream_solve_user_problem(user_query, messages):
-        yield [msg for msg in message_batch] # Ensure correct format for Gradio Chatbot
-
-    # Save steps
-    save_steps_data(
-        query_id=query_id,
-        memory=memory
-    )
-
-
-def main(args):
-    #################### Gradio Interface ####################
-    # with gr.Blocks() as demo:
-    with gr.Blocks(theme=gr.themes.Ocean()) as demo:
-        # Theming https://www.gradio.app/guides/theming-guide
-
-        gr.Markdown("# 🌀💫 Chat with AgentFlow: A Trainable Agentic Framework for Complex Reasoning") # Title
-        gr.Markdown("""
-        **AgentFlow** is a **trainable, tool-integrated agentic framework** designed to overcome the scalability and generalization limits of today's tool-augmented reasoning approaches. It introduces a **modular agentic system** (🧭 Planner, 🛠 Executor, ✅ Verifier, and ✍️ Generator) and an **in-the-flow RL algorithm (Flow-GRPO)** to optimize the agent within the system for **effective planning and tool use**.
-
-        [Website](https://agentflow.stanford.edu/) |
-        [HF Paper](https://huggingface.co/papers/2510.05592) |
-        [GitHub](https://github.com/lupantech/AgentFlow) |
-        [Model](https://huggingface.co/AgentFlow/agentflow-planner-7b) |
-        [YouTube](https://www.youtube.com/watch?v=kIQbCQIH1SI) |
-        [X (Twitter)](https://x.com/lupantech/status/1976016000345919803) |
-        [Slack](https://join.slack.com/t/agentflow-co/shared_invite/zt-3f712xngl-LfxS4gmftAeKvcxR3nSkWQ)
-
-        > ⏳ **Note:** The first query may take ~20 seconds to initialize AgentFlow. Subsequent queries will be super fast.
-        >
-        > 💡 **Tip:** If the wait time is too long, please try again later.
-        """)
-
-        with gr.Row():
-            # Left column for settings
-            with gr.Column(scale=1):
-                # with gr.Row():
-                #     if args.openai_api_source == "user_provided":
-                #         print("Using API key from user input.")
-                #         api_key = gr.Textbox(
-                #             show_label=True,
-                #             placeholder="Your API key will not be stored in any way.",
-                #             type="password",
-                #             label="OpenAI API Key",
-                #             # container=False
-                #         )
-                #     else:
-                #         print(f"Using local API key from environment variable: ...{os.getenv('OPENAI_API_KEY')[-4:]}")
-                #         api_key = gr.Textbox(
-                #             value=os.getenv("OPENAI_API_KEY"),
-                #             visible=True,
-                #             interactive=False
-                #         )
-
-
-                with gr.Row():
-                    llm_model_engine = gr.Textbox(
-                        value="vllm-AgentFlow/agentflow-planner-7b",
-                        label="🧭 Planner Model",
-                        interactive=False
-                    )
-
-                with gr.Row():
-                    gr.Textbox(
-                        value="Qwen2.5-7B-Instruct",
-                        label="🛠 Executor, ✅ Verifier, and ✍️ Generator Model",
-                        interactive=False
-                    )
-
-
-                with gr.Row():
-                    vllm_status = gr.Textbox(
-                        value=get_vllm_status(),
-                        label="vLLM Status",
-                        interactive=False,
-                        scale=4
-                    )
-                    refresh_status_btn = gr.Button("🔄 Refresh", scale=1)
-
-                # Add click handler for refresh button
-                refresh_status_btn.click(
-                    fn=get_vllm_status,
-                    inputs=[],
-                    outputs=vllm_status
-                )
-
-                with gr.Row():
-                    max_steps = gr.Slider(value=5, minimum=1, maximum=10, step=1, label="Max Steps")
-
-                with gr.Row():
-                    max_time = gr.Slider(value=240, minimum=60, maximum=300, step=30, label="Max Time (seconds)")
-
-                with gr.Row():
-                    # Container for tools section
-                    with gr.Column():
-
-                        # First row for checkbox group
-                        enabled_tools = gr.CheckboxGroup(
-                            choices=all_tools,
-                            value=all_tools,
-                            label="Selected Tools",
-                        )
-
-                        # Second row for buttons
-                        with gr.Row():
-                            enable_all_btn = gr.Button("Select All Tools")
-                            disable_all_btn = gr.Button("Clear All Tools")
-
-                        # Add click handlers for the buttons
-                        enable_all_btn.click(
-                            lambda: all_tools,
-                            outputs=enabled_tools
-                        )
-                        disable_all_btn.click(
-                            lambda: [],
-                            outputs=enabled_tools
-                        )
-
-            with gr.Column(scale=5):
-
-                with gr.Row():
-                    # Middle column for the query
-                    with gr.Column(scale=2):
-                        with gr.Row():
-                            user_query = gr.Textbox(value="How many r letters are in the word strawberry?", placeholder="Type your question here...", label="Question (Required)", lines=3)
-
-                        with gr.Row():
-                            run_button = gr.Button("🌀💫 Submit and Run", variant="primary") # Run button with blue color
-
-                    # Right column for the output
-                    with gr.Column(scale=3):
-                        chatbot_output = gr.Chatbot(type="messages", label="Step-wise Problem-Solving Output", height=500)
-
-                        # TODO: Add actions to the buttons
-                        with gr.Row(elem_id="buttons") as button_row:
-                            upvote_btn = gr.Button(value="👍 Upvote", interactive=True, variant="primary")
-                            downvote_btn = gr.Button(value="👎 Downvote", interactive=True, variant="primary")
-                            # stop_btn = gr.Button(value="⛔️ Stop", interactive=True) # TODO
-                            # clear_btn = gr.Button(value="🗑️ Clear history", interactive=True) # TODO
-
-                        # TODO: Add comment textbox
-                        with gr.Row():
-                            comment_textbox = gr.Textbox(value="",
-                                                         placeholder="Feel free to add any comments here. Thanks for using AgentFlow!",
-                                                         label="💬 Comment (Type and press Enter to submit.)", interactive=True)
-
-                        # Update the button click handlers
-                        upvote_btn.click(
-                            fn=lambda: (save_feedback(QUERY_ID, "upvote"), gr.Info("Thank you for your upvote! 🙌")),
-                            inputs=[],
-                            outputs=[]
-                        )
-
-                        downvote_btn.click(
-                            fn=lambda: (save_feedback(QUERY_ID, "downvote"), gr.Info("Thank you for your feedback. We'll work to improve! 🙏")),
-                            inputs=[],
-                            outputs=[]
-                        )
-
-                        # Add handler for comment submission
-                        comment_textbox.submit(
-                            fn=lambda comment: (save_feedback(QUERY_ID, "comment", comment), gr.Info("Thank you for your comment! ✨")),
-                            inputs=[comment_textbox],
-                            outputs=[]
-                        )
-
-        # Bottom row for examples
-        with gr.Row():
-            with gr.Column(scale=5):
-                gr.Markdown("")
-                gr.Markdown("""
-                ## 🚀 Try these examples with suggested tools.
-                """)
-                gr.Examples(
-                    examples=[
-                        ["General Knowledge",
-                         "What is the capital of France?",
-                         ["Base_Generator_Tool"],
-                         "Paris"],
-
-                        ["Logical Reasoning",
-                         "How many r letters are in the word strawberry?",
-                         ["Base_Generator_Tool", "Python_Coder_Tool"],
-                         "3"],
-
-                        ["Web Search",
-                         "Who is the mother-in-law of Olivera Despina?",
-                         ["Base_Generator_Tool", "Google_Search_Tool", "Wikipedia_Search_Tool", "Web_Search_Tool"],
-                         "Gülçiçek Hatun"],
-
-                        ["Agentic Search",
-                         "The object in the British Museum's collection with a museum number of 2012,5015.17 is the shell of a particular mollusk species. According to the abstract of a research article published in Science Advances in 2021, beads made from the shells of this species were found that are at least how many thousands of years old?",
-                         ["Base_Generator_Tool", "Python_Coder_Tool", "Google_Search_Tool", "Wikipedia_Search_Tool", "Web_Search_Tool"],
-                         "142,000"],
-
-                        ["Arithmetic Reasoning",
-                         "Which is bigger, 9.11 or 9.9?",
-                         ["Base_Generator_Tool", "Python_Coder_Tool"],
-                         "9.9"],
-
-                        ["Multi-step Reasoning",
-                         "Using the numbers [1, 1, 6, 9], create an expression that equals 24. You must use basic arithmetic operations (+, -, ×, /) and parentheses. For example, one solution for [1, 2, 3, 4] is (1+2+3)×4.",
-                         ["Python_Coder_Tool"],
-                         "((1 + 1) * 9) + 6"],
-
-                        ["Scientific Reasoning",
-                         "An investigator is studying cellular regeneration of epithelial cells. She has obtained a tissue sample from a normal thyroid gland for histopathologic examination. It shows follicles lined by a single layer of cube-like cells with large central nuclei. Which of the following parts of the female reproductive tract is also lined by this type of epithelium?\nA. Ovaries\nB. Vagina\nC. Fallopian tubes\nD. Vulva\nChoose the correct option.",
-                         ["Base_Generator_Tool", "Google_Search_Tool", "Wikipedia_Search_Tool", "Python_Coder_Tool"],
-                         "A. Ovaries"],
-                    ],
-                    inputs=[gr.Textbox(label="Category", visible=False), user_query, enabled_tools, gr.Textbox(label="Reference Answer", visible=False)],
-                    # label="Try these examples with suggested tools."
-                )
-
-        # Link button click to function
-        run_button.click(
-            fn=solve_problem_gradio,
-            inputs=[user_query, max_steps, max_time, llm_model_engine, enabled_tools],
-            outputs=chatbot_output,
-            concurrency_limit=10, # A10 GPU can handle ~10 concurrent requests with vLLM
-            concurrency_id="agentflow_solver" # Shared queue for managing GPU resource
-        )
-    #################### Gradio Interface ####################
-
-    # Configure queue for high traffic - optimized for A10 GPU (40G RAM, 24G VRAM)
-    demo.queue(
-        default_concurrency_limit=10, # Balanced for A10 GPU + vLLM inference
-        max_size=50, # Allow up to 50 requests in queue for traffic spikes
-    )
-
-    # Launch the Gradio app with optimized threading
-    # demo.launch(ssr_mode=False)
-    demo.launch(
-        ssr_mode=False,
-        share=True,
-        max_threads=80 # Increase from default 40 to support high concurrency
    )
 
 if __name__ == "__main__":
-
-
-    args = parse_arguments()
-
-    # All tools for AgentFlow
-    all_tools = [
-        "Base_Generator_Tool",
-        "Python_Coder_Tool",
-        "Google_Search_Tool",
-        "Wikipedia_Search_Tool",
-        "Web_Search_Tool"
-    ]
-    args.enabled_tools = ",".join(all_tools)
-
-    # NOTE: Use the same name for the query cache directory as the dataset directory
-    args.root_cache_dir = DATASET_DIR.name
-
-    # Start vLLM service
-    print("=" * 60)
-    print("🔍 Checking vLLM service status...")
-    if not check_vllm_service():
-        print(f"⚠️ vLLM service not running. Starting {VLLM_MODEL_NAME}...")
-        start_vllm_service()
-    else:
-        print(f"✅ vLLM service is already running on port {VLLM_PORT}")
-    print("=" * 60)
-
-    # Register cleanup function
-    # atexit.register(stop_vllm_service)
-
-    main(args)
-
 import gradio as gr
+import cv2
+import numpy as np
+from PIL import Image
+import tempfile
+import os
 
+def extract_canny_edges(video_path, low_threshold=50, high_threshold=150):
     """
+    Extract Canny edges from a video.
     """
+    cap = cv2.VideoCapture(video_path)
 
+    # Read the video metadata
+    fps = int(cap.get(cv2.CAP_PROP_FPS))
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
+    # Create a temporary output file
+    output_path = tempfile.mktemp(suffix='.mp4')
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
 
+    frame_count = 0
+    canny_frames = []
 
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        # Convert to grayscale
+        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+
+        # Apply a Gaussian blur to reduce noise
+        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
+
+        # Extract the Canny edges
+        edges = cv2.Canny(blurred, low_threshold, high_threshold)
+
+        # Convert back to BGR for writing
+        edges_bgr = cv2.cvtColor(edges, cv2.COLOR_GRAY2BGR)
+
+        # Write the frame to the output video
+        out.write(edges_bgr)
+
+        # Keep every 5th frame for the preview gallery
+        if frame_count % 5 == 0:
+            edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
+            canny_frames.append(Image.fromarray(edges_rgb))
+
+        frame_count += 1
+
+    cap.release()
+    out.release()
+
+    return output_path, canny_frames, frame_count, fps
 
+def process_video(video_path, low_threshold, high_threshold):
+    """
+    Process the video and extract its motion (edge) map.
+    """
+    if video_path is None:
+        return None, None, "❌ Please upload a video"
 
     try:
+        output_video, preview_frames, total_frames, fps = extract_canny_edges(
+            video_path,
+            int(low_threshold),
+            int(high_threshold)
+        )
+
+        status = f"""
+✅ Motion extraction completed successfully!
+
+📊 Info:
+• Total frames: {total_frames}
+• FPS: {fps}
+• Duration: {total_frames/fps:.2f} seconds
 """
 
+        return output_video, preview_frames, status
+
+    except Exception as e:
+        return None, None, f"❌ Error: {str(e)}"
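Both helpers above are plain functions, so the edge extraction can also run headless, outside the Gradio UI. A minimal sketch (input.mp4 and preview_0.png are placeholder paths):

    from app import extract_canny_edges  # the new app.py

    edges_path, previews, n_frames, fps = extract_canny_edges(
        "input.mp4", low_threshold=50, high_threshold=150
    )
    print(f"{n_frames} frames at {fps} fps -> {edges_path}")
    previews[0].save("preview_0.png")  # every 5th frame is kept as a PIL image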
 
+# Gradio interface
+with gr.Blocks(title="Wan2.1 Canny Edge Extractor", theme=gr.themes.Soft()) as demo:
+
+    gr.Markdown("""
+    # 🎬 Video Motion Extraction (Canny Edge Detection)
+
+    This tool uses the **Canny Edge Detection** algorithm to extract the edges and motion of your video.
+
+    The output can be used as input for the **Wan2.1 ControlNet** model.
+    """)
+
+    with gr.Row():
+        with gr.Column():
+            input_video = gr.Video(
+                label="📹 Input video",
+                height=400
            )
+
+            gr.Markdown("### ⚙️ Canny settings")
+
+            low_threshold = gr.Slider(
+                minimum=0,
+                maximum=255,
+                value=50,
+                step=1,
+                label="Low threshold",
+                info="Lower value = more edges"
            )
+
+            high_threshold = gr.Slider(
+                minimum=0,
+                maximum=255,
+                value=150,
+                step=1,
+                label="High threshold",
+                info="Higher value = only strong edges"
+            )
+
+            process_btn = gr.Button(
+                "🚀 Extract motion",
+                variant="primary",
+                size="lg"
+            )
+
+        with gr.Column():
+            status_text = gr.Textbox(
+                label="Status",
+                lines=6,
+                interactive=False
+            )
+
+            output_video = gr.Video(
+                label="🎥 Canny edges video",
+                height=400
+            )
+
+            preview_gallery = gr.Gallery(
+                label="🖼️ Frame previews",
+                columns=4,
+                height=300
+            )
 
+    gr.Markdown("""
+    ---
+    ### 📖 Usage guide:
+
+    1. **Upload a video**: upload a video (recommended: shorter than 30 seconds)
+    2. **Tune the thresholds**:
+       - Low threshold: controls how sensitive edge detection is
+       - High threshold: filters out weak edges
+    3. **Extract**: click the "Extract motion" button
+    4. **Download**: download the output video
+
+    ### 💡 Tips:
+    - **Low thresholds** (e.g. 30-100): more detail, more noise
+    - **Medium thresholds** (e.g. 50-150): recommended - a good balance
+    - **High thresholds** (e.g. 100-200): only the main edges
+
+    ### 🔗 Use with Wan2.1:
+    The output video can be used as **ControlNet conditioning** for the Wan2.1 model.
+
+    ---
+
+    🔗 Model: [TheDenk/wan2.1-t2v-1.3b-controlnet-canny-v1](https://huggingface.co/TheDenk/wan2.1-t2v-1.3b-controlnet-canny-v1)
+    """)
+
+    # Connect the button to the processing function
+    process_btn.click(
+        fn=process_video,
+        inputs=[input_video, low_threshold, high_threshold],
+        outputs=[output_video, preview_gallery, status_text]
    )
 
+# Run the application
 if __name__ == "__main__":
+    demo.launch(share=True)
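To see what the threshold tips in the new app's guide mean in practice, the two Canny values can be compared on a single frame before committing to a full video pass. A small sketch (frame.png is a placeholder path):

    import cv2

    frame = cv2.imread("frame.png", cv2.IMREAD_GRAYSCALE)
    blurred = cv2.GaussianBlur(frame, (5, 5), 0)  # same blur the app applies
    # From most detail (and noise) to only the strongest edges:
    for low, high in [(30, 100), (50, 150), (100, 200)]:
        edges = cv2.Canny(blurred, low, high)
        cv2.imwrite(f"edges_{low}_{high}.png", edges)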