Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import random | |
| import operator | |
| from typing import List, Dict, Any, TypedDict, Annotated | |
| from dotenv import load_dotenv | |
| from langchain_core.tools import tool | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langchain_community.document_loaders import WikipediaLoader | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain_core.messages import SystemMessage, HumanMessage, AIMessage | |
| from langchain_groq import ChatGroq | |
| load_dotenv() # expects GROQ_API_KEY in your .env | |
def multiply(a: int, b: int) -> int:
    """Return the product of ``a`` and ``b``."""
    product = a * b
    return product
def add(a: int, b: int) -> int:
    """Return the sum of ``a`` and ``b``."""
    total = a + b
    return total
def subtract(a: int, b: int) -> int:
    """Return ``a`` minus ``b``."""
    difference = a - b
    return difference
def divide(a: int, b: int) -> float:
    """Return ``a`` divided by ``b`` using true division.

    Raises:
        ValueError: if ``b`` equals zero.
    """
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")
def modulus(a: int, b: int) -> int:
    """Return the remainder left after dividing ``a`` by ``b``."""
    remainder = a % b
    return remainder
def optimized_web_search(query: str) -> str:
    """Search the web with Tavily and return the hits as ``<Doc>`` snippets.

    At most two results are requested. Each hit is rendered as
    ``<Doc url='...'>content</Doc>`` with the content cut to 500 characters,
    and the rendered hits are joined with a ``---`` separator line.

    Args:
        query: Free-text search query.

    Returns:
        The formatted snippets, or a string starting with
        ``"Web search failed: "`` when the search could not be completed.
        This function reports failures in its return value and does not raise.
    """
    try:
        # Short random pause before every request.
        time.sleep(random.uniform(0.7, 1.5))
        # ``invoke`` takes the tool input as its first positional argument.
        # Passing ``query=`` as a keyword leaves that argument unfilled, which
        # raises TypeError and turns every search into the failure message.
        docs = TavilySearchResults(max_results=2).invoke({"query": query})
        if isinstance(docs, str):
            # NOTE(review): some tool versions appear to report their own
            # errors as plain text instead of a list of hits - surface that
            # text rather than failing on ``str.get`` below.
            return f"Web search failed: {docs}"
        return "\n\n---\n\n".join(
            # ``or ''`` guards against a hit whose content is present but None.
            f"<Doc url='{d.get('url', '')}'>{(d.get('content') or '')[:500]}</Doc>"
            for d in docs
        )
    except Exception as e:
        # Deliberate best-effort: the caller always gets a string back.
        return f"Web search failed: {e}"
def optimized_wiki_search(query: str) -> str:
    """Look up ``query`` on Wikipedia and return the result as ``<Doc>`` text.

    At most one document is loaded. Each document is rendered as
    ``<Doc src='...'>text</Doc>`` with the page text cut to 800 characters;
    the ``src`` falls back to ``Wikipedia`` when the metadata has no source.
    Any exception is reported in the returned string instead of being raised.
    """
    try:
        # Short random pause before every request.
        time.sleep(random.uniform(0.3, 1))
        pages = WikipediaLoader(query=query, load_max_docs=1).load()
        snippets = [
            f"<Doc src='{page.metadata.get('source','Wikipedia')}'>{page.page_content[:800]}</Doc>"
            for page in pages
        ]
        return "\n\n---\n\n".join(snippets)
    except Exception as err:
        return f"Wikipedia search failed: {err}"
class EnhancedAgentState(TypedDict):
    """State dictionary carried through the routing graph."""

    # Conversation messages; the ``operator.add`` annotation is the reducer
    # metadata attached to this key (updates are combined by concatenation).
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    # Raw user query text.
    query: str
    # Route key naming the model node chosen for this query.
    agent_type: str
    # Text of the model's reply.
    final_answer: str
    # Timing/provider details recorded by the node that answered.
    perf: Dict[str, Any]
    # NOTE(review): nothing in the visible code reads or writes this key.
    agno_resp: str
class HybridLangGraphMultiLLMSystem:
    """
    Router that picks between Groq-hosted Llama-3 8B, Llama-3 70B (default),
    and Groq-hosted DeepSeek-Chat according to the query content.

    The compiled graph has one ``router`` node followed by exactly one model
    node, selected by a substring match on the lower-cased query:
    ``"llama-8"`` -> ``llama8``, ``"deepseek"`` -> ``deepseek``, anything
    else -> ``llama70``.

    Attributes:
        tools: The arithmetic and search helper functions. They are stored
            here but are not bound to any model by this class.
        graph: The compiled LangGraph graph, checkpointed in memory.
    """

    # Route key -> (Groq model id, provider label reported in ``perf``).
    # NOTE(review): "deepseek-chat" looks like DeepSeek's own API model name;
    # confirm that Groq actually serves a model under this id.
    _MODELS = {
        "llama8": ("llama3-8b-8192", "Groq-Llama3-8B"),
        "llama70": ("llama3-70b-8192", "Groq-Llama3-70B"),
        "deepseek": ("deepseek-chat", "Groq-DeepSeek"),
    }
    _SYSTEM_PROMPT = "You are a helpful AI assistant."

    def __init__(self):
        """Collect the helper tools and compile the routing graph."""
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search,
        ]
        self.graph = self._build_graph()

    def _llm(self, model_name: str):
        """Return a deterministic (temperature 0) Groq chat client.

        The API key is read from the ``GROQ_API_KEY`` environment variable.
        """
        return ChatGroq(
            model=model_name,
            temperature=0,
            api_key=os.getenv("GROQ_API_KEY"),
        )

    @staticmethod
    def _route(query: str) -> str:
        """Return the route key for ``query`` (case-insensitive match)."""
        lowered = query.lower()
        if "llama-8" in lowered:
            return "llama8"
        if "deepseek" in lowered:
            return "deepseek"
        return "llama70"

    def _make_model_node(self, model_name: str, provider: str):
        """Build the graph node that answers the query with one model.

        Args:
            model_name: Groq model id passed to ``_llm``.
            provider: Label stored under ``perf["prov"]``.

        Returns:
            A node function mapping the current state to a partial update
            containing ``final_answer`` and ``perf``.
        """
        llm = self._llm(model_name)
        system_prompt = self._SYSTEM_PROMPT

        def model_node(st: EnhancedAgentState) -> Dict[str, Any]:
            # perf_counter is monotonic, so the elapsed time cannot be
            # skewed by wall-clock adjustments.
            started = time.perf_counter()
            reply = llm.invoke([
                SystemMessage(content=system_prompt),
                HumanMessage(content=st["query"]),
            ])
            # Return only the keys this node changes. Echoing the whole state
            # back would send ``messages`` through its ``operator.add``
            # reducer again and duplicate every message on each step.
            return {
                "final_answer": reply.content,
                "perf": {
                    "time": time.perf_counter() - started,
                    "prov": provider,
                },
            }

        return model_node

    def _build_graph(self):
        """Assemble and compile the router -> model -> END graph."""

        def router(st: EnhancedAgentState) -> Dict[str, Any]:
            # Partial update for the same reason as the model nodes.
            return {"agent_type": self._route(st["query"])}

        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        for route_key, (model_name, provider) in self._MODELS.items():
            g.add_node(route_key, self._make_model_node(model_name, provider))
        g.set_entry_point("router")
        g.add_conditional_edges(
            "router",
            lambda s: s["agent_type"],
            {route_key: route_key for route_key in self._MODELS},
        )
        for route_key in self._MODELS:
            g.add_edge(route_key, END)
        return g.compile(checkpointer=MemorySaver())

    def process_query(self, q: str) -> str:
        """Run ``q`` through the graph and return the answer text.

        Args:
            q: The user query; its content also selects the model.

        Returns:
            The ``final_answer`` from the resulting state with surrounding
            whitespace stripped (empty string when the key is absent).
        """
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": "",
        }
        # The thread id is derived from the query text, so asking the same
        # question again in one process reuses the same checkpoint thread.
        cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}}
        out = self.graph.invoke(state, cfg)
        return out.get("final_answer", "").strip()
def build_graph(provider: str | None = None):
    """Create a fresh hybrid system and return its compiled graph.

    ``provider`` is accepted but not used by this function.
    """
    system = HybridLangGraphMultiLLMSystem()
    return system.graph
if __name__ == "__main__":
    qa_system = HybridLangGraphMultiLLMSystem()
    # One smoke-test prompt per routed model.
    demo_queries = (
        "llama-8: What is the capital of France?",
        "llama-70: Tell me about quantum mechanics.",
        "deepseek: What is the Riemann Hypothesis?",
    )
    for demo_query in demo_queries:
        print(qa_system.process_query(demo_query))