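"""Hybrid LangGraph multi-LLM agent system.

Routes each query to one of several providers (Groq, NVIDIA, Gemini,
DeepSeek, Baidu ERNIE) through a LangGraph StateGraph, using simple
keyword-based routing plus math and search tools.
"""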
import os
import time
import random
import operator
from typing import List, Dict, Any, TypedDict, Annotated

import requests
from dotenv import load_dotenv
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver

# Load environment variables
load_dotenv()
# ---- Tool Definitions ----
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the product."""
    return a * b

def add(a: int, b: int) -> int:
    """Add two integers and return the sum."""
    return a + b

def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first and return the difference."""
    return a - b

def divide(a: int, b: int) -> float:
    """Divide the first integer by the second and return the quotient."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

def modulus(a: int, b: int) -> int:
    """Return the remainder of the division of the first integer by the second."""
    if b == 0:
        raise ValueError("Cannot take modulus by zero.")
    return a % b
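# ---- Search Tools ----
# Both searches sleep for a short random interval before calling out,
# a light form of jitter to avoid hammering rate-limited APIs.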
def optimized_web_search(query: str) -> str:
    """Perform an optimized web search using TavilySearchResults and return concatenated document snippets."""
    try:
        time.sleep(random.uniform(1, 2))
        search_tool = TavilySearchResults(max_results=2)
        docs = search_tool.invoke({"query": query})
        return "\n\n---\n\n".join(
            f"<Doc url='{d.get('url','')}'>{d.get('content','')[:500]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Web search failed: {e}"

def optimized_wiki_search(query: str) -> str:
    """Perform an optimized Wikipedia search and return concatenated document snippets."""
    try:
        time.sleep(random.uniform(0.5, 1))
        docs = WikipediaLoader(query=query, load_max_docs=1).load()
        return "\n\n---\n\n".join(
            f"<Doc src='{d.metadata.get('source', 'Wikipedia')}'>{d.page_content[:800]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Wikipedia search failed: {e}"
# ---- LLM Integrations with Error Handling ----
try:
    from langchain_groq import ChatGroq
    GROQ_AVAILABLE = True
except ImportError:
    GROQ_AVAILABLE = False

try:
    from langchain_nvidia_ai_endpoints import ChatNVIDIA
    NVIDIA_AVAILABLE = True
except ImportError:
    NVIDIA_AVAILABLE = False

try:
    import google.generativeai as genai
    GEMINI_AVAILABLE = True
except ImportError:
    GEMINI_AVAILABLE = False
def deepseek_generate(prompt, api_key=None):
    """Call the DeepSeek chat-completions API and return the response text."""
    if not api_key:
        return "DeepSeek API key not provided"
    url = "https://api.deepseek.com/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    data = {
        "model": "deepseek-chat",
        "messages": [{"role": "user", "content": prompt}],
        "stream": False
    }
    try:
        resp = requests.post(url, headers=headers, json=data, timeout=30)
        resp.raise_for_status()
        choices = resp.json().get("choices", [])
        if choices and "message" in choices[0]:
            return choices[0]["message"].get("content", "")
        return "No response from DeepSeek"
    except Exception as e:
        return f"DeepSeek API error: {e}"
def baidu_ernie_generate(prompt, api_key=None):
    """Call Baidu ERNIE API (placeholder implementation)."""
    if not api_key:
        return "Baidu ERNIE API key not provided"
    # Note: This is a placeholder. Replace with the actual Baidu ERNIE API endpoint.
    try:
        return f"Baidu ERNIE response for: {prompt[:50]}..."
    except Exception as e:
        return f"ERNIE API error: {e}"
# ---- Graph State ----
class EnhancedAgentState(TypedDict):
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    query: str
    agent_type: str
    final_answer: str
    perf: Dict[str, Any]
    agno_resp: str
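# The operator.add annotation makes "messages" an append-only channel:
# e.g. if the state holds [HumanMessage("hi")] and a node returns
# {"messages": [AIMessage("hello")]}, the merged list contains both.
# The remaining keys are plain overwrites.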
class HybridLangGraphMultiLLMSystem:
    def __init__(self, provider="groq"):
        self.provider = provider
        # Collected for future tool binding; the provider nodes below
        # currently call the LLMs directly without tools.
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search
        ]
        self.graph = self._build_graph()

    def _build_graph(self):
        # Initialize LLMs with error handling
        groq_llm = None
        nvidia_llm = None

        if GROQ_AVAILABLE and os.getenv("GROQ_API_KEY"):
            try:
                groq_llm = ChatGroq(
                    model="llama3-70b-8192",
                    temperature=0,
                    api_key=os.getenv("GROQ_API_KEY")
                )
            except Exception as e:
                print(f"Failed to initialize Groq: {e}")

        if NVIDIA_AVAILABLE and os.getenv("NVIDIA_API_KEY"):
            try:
                nvidia_llm = ChatNVIDIA(
                    model="meta/llama3-70b-instruct",
                    temperature=0,
                    api_key=os.getenv("NVIDIA_API_KEY")
                )
            except Exception as e:
                print(f"Failed to initialize NVIDIA: {e}")
        def router(st: EnhancedAgentState) -> EnhancedAgentState:
            q = st["query"].lower()
            if "groq" in q and groq_llm:
                t = "groq"
            elif "nvidia" in q and nvidia_llm:
                t = "nvidia"
            elif ("gemini" in q or "google" in q) and GEMINI_AVAILABLE:
                t = "gemini"
            elif "deepseek" in q:
                t = "deepseek"
            elif "ernie" in q or "baidu" in q:
                t = "baidu"
            else:
                # Default to the first available provider
                if groq_llm:
                    t = "groq"
                elif nvidia_llm:
                    t = "nvidia"
                elif GEMINI_AVAILABLE:
                    t = "gemini"
                else:
                    t = "deepseek"
            return {**st, "agent_type": t}
        def groq_node(st: EnhancedAgentState) -> EnhancedAgentState:
            if not groq_llm:
                return {**st, "final_answer": "Groq not available", "perf": {"error": "No Groq LLM"}}
            t0 = time.time()
            try:
                sys = SystemMessage(content="You are a helpful AI assistant. Provide accurate and detailed answers.")
                res = groq_llm.invoke([sys, HumanMessage(content=st["query"])])
                return {**st, "final_answer": res.content, "perf": {"time": time.time() - t0, "prov": "Groq"}}
            except Exception as e:
                return {**st, "final_answer": f"Groq error: {e}", "perf": {"error": str(e)}}

        def nvidia_node(st: EnhancedAgentState) -> EnhancedAgentState:
            if not nvidia_llm:
                return {**st, "final_answer": "NVIDIA not available", "perf": {"error": "No NVIDIA LLM"}}
            t0 = time.time()
            try:
                sys = SystemMessage(content="You are a helpful AI assistant. Provide accurate and detailed answers.")
                res = nvidia_llm.invoke([sys, HumanMessage(content=st["query"])])
                return {**st, "final_answer": res.content, "perf": {"time": time.time() - t0, "prov": "NVIDIA"}}
            except Exception as e:
                return {**st, "final_answer": f"NVIDIA error: {e}", "perf": {"error": str(e)}}

        def gemini_node(st: EnhancedAgentState) -> EnhancedAgentState:
            if not GEMINI_AVAILABLE:
                return {**st, "final_answer": "Gemini not available", "perf": {"error": "Gemini not installed"}}
            t0 = time.time()
            try:
                api_key = os.getenv("GEMINI_API_KEY")
                if not api_key:
                    return {**st, "final_answer": "Gemini API key not provided", "perf": {"error": "No API key"}}
                genai.configure(api_key=api_key)
                model = genai.GenerativeModel("gemini-1.5-pro-latest")
                res = model.generate_content(st["query"])
                return {**st, "final_answer": res.text, "perf": {"time": time.time() - t0, "prov": "Gemini"}}
            except Exception as e:
                return {**st, "final_answer": f"Gemini error: {e}", "perf": {"error": str(e)}}

        def deepseek_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            try:
                resp = deepseek_generate(st["query"], api_key=os.getenv("DEEPSEEK_API_KEY"))
                return {**st, "final_answer": resp, "perf": {"time": time.time() - t0, "prov": "DeepSeek"}}
            except Exception as e:
                return {**st, "final_answer": f"DeepSeek error: {e}", "perf": {"error": str(e)}}

        def baidu_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            try:
                resp = baidu_ernie_generate(st["query"], api_key=os.getenv("BAIDU_API_KEY"))
                return {**st, "final_answer": resp, "perf": {"time": time.time() - t0, "prov": "ERNIE"}}
            except Exception as e:
                return {**st, "final_answer": f"ERNIE error: {e}", "perf": {"error": str(e)}}
        def pick(st: EnhancedAgentState) -> str:
            return st["agent_type"]

        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        g.add_node("groq", groq_node)
        g.add_node("nvidia", nvidia_node)
        g.add_node("gemini", gemini_node)
        g.add_node("deepseek", deepseek_node)
        g.add_node("baidu", baidu_node)
        g.set_entry_point("router")
        g.add_conditional_edges("router", pick, {
            "groq": "groq",
            "nvidia": "nvidia",
            "gemini": "gemini",
            "deepseek": "deepseek",
            "baidu": "baidu"
        })
        for n in ["groq", "nvidia", "gemini", "deepseek", "baidu"]:
            g.add_edge(n, END)
        return g.compile(checkpointer=MemorySaver())
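    # The compiled graph checkpoints with MemorySaver: state is kept in
    # process memory only, keyed by thread_id, and does not survive a
    # restart. A persistent checkpointer (e.g. SqliteSaver) could be
    # swapped in if durable threads are needed.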
    def process_query(self, q: str) -> str:
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": ""
        }
        # hash(q) is salted per interpreter run (PYTHONHASHSEED), so thread
        # ids are only stable within a single process.
        cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}}
        try:
            out = self.graph.invoke(state, cfg)
            raw_answer = out.get("final_answer", "No answer generated")
            # Heuristic cleanup: if the model prefixed a short preamble
            # paragraph, keep the second paragraph instead.
            if isinstance(raw_answer, str):
                parts = raw_answer.split('\n\n')
                if len(parts) > 1 and len(parts[1].strip()) > 10:
                    return parts[1].strip()
                return raw_answer.strip()
            return str(raw_answer)
        except Exception as e:
            return f"Error processing query: {e}"
# Function expected by app.py
def build_graph(provider="groq"):
    """Build and return the compiled graph for the agent system."""
    system = HybridLangGraphMultiLLMSystem(provider=provider)
    return system.graph
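# Minimal usage sketch for the compiled graph (field names match
# EnhancedAgentState above; "demo" is an arbitrary thread id):
#   graph = build_graph()
#   out = graph.invoke(
#       {"messages": [], "query": "What is 6 * 7?", "agent_type": "",
#        "final_answer": "", "perf": {}, "agno_resp": ""},
#       {"configurable": {"thread_id": "demo"}},
#   )
#   print(out["final_answer"])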
if __name__ == "__main__":
    query = "What are the names of the US presidents who were assassinated?"
    system = HybridLangGraphMultiLLMSystem()
    result = system.process_query(query)
    print("LangGraph Hybrid Result:", result)