|
|
|
|
|
|
|
|
import os |
|
|
from typing import TypedDict, Annotated, List |
|
|
import operator |
|
|
from dotenv import load_dotenv |
|
|
from langchain_openai import ChatOpenAI |
|
|
from langchain_anthropic import ChatAnthropic |
|
|
from langchain.schema import BaseMessage, AIMessage |
|
|
from langgraph.graph import StateGraph, END |
|
|
from langgraph.checkpoint.memory import MemorySaver |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        messages: The list of messages comprising the conversation.
            The ``operator.add`` annotation marks how updates to this key are
            merged: a node returning ``{"messages": [...]}`` has that list
            appended to (concatenated onto) the existing history instead of
            replacing it.
    """

    # The Annotated metadata (operator.add) acts as this key's reducer for LangGraph.
    messages: Annotated[List[BaseMessage], operator.add]
|
|
|
|
|
|
|
|
|
|
|
def initialize_llm(provider: str, model_name: str, temperature: float, api_key: str):
    """Build the LangChain chat model for the requested provider.

    Args:
        provider: ``"OpenAI"`` or ``"Anthropic"`` (exact, case-sensitive match).
        model_name: Model identifier forwarded to the chat-model constructor.
        temperature: Sampling temperature forwarded to the chat-model constructor.
        api_key: Credential for the provider; must be truthy.

    Returns:
        A ``ChatOpenAI`` instance for ``"OpenAI"``, or a ``ChatAnthropic``
        instance for ``"Anthropic"``.

    Raises:
        ValueError: If ``provider`` is not supported, or if ``api_key`` is empty/falsy.
    """
    supported_providers = ("OpenAI", "Anthropic")

    # Unknown providers are rejected first, so that error wins even when no key was given.
    if provider not in supported_providers:
        raise ValueError(f"Unsupported LLM provider: {provider}")

    use_openai = provider == "OpenAI"

    if not api_key:
        if use_openai:
            raise ValueError("OpenAI API key is missing. Please set OPENAI_API_KEY.")
        raise ValueError("Anthropic API key is missing. Please set ANTHROPIC_API_KEY.")

    chat_model_cls = ChatOpenAI if use_openai else ChatAnthropic
    return chat_model_cls(api_key=api_key, model_name=model_name, temperature=temperature)
|
|
|
|
|
|
|
|
|
|
|
def create_chat_graph(llm, *, checkpointer=None):
    """
    Builds and compiles the LangGraph conversational graph.

    The graph has a single node, ``llm_node``. It sends the accumulated
    ``messages`` to ``llm``, appends the model's reply to the state, and then
    ends the run.

    Args:
        llm: An initialized LangChain Chat Model instance (must provide ``invoke``).
        checkpointer: Optional LangGraph checkpointer (for example
            ``MemorySaver()``), forwarded to ``compile``. With a checkpointer,
            LangGraph can persist the conversation between invocations that
            share a ``thread_id`` in their config. With the default ``None``,
            the graph compiles exactly as a plain ``compile()`` call would, and
            callers must pass the full history on each invocation.

    Returns:
        A compiled LangGraph application.
    """

    def call_model(state: GraphState) -> dict:
        """Invokes the provided LLM with the current conversation state."""
        response = llm.invoke(state["messages"])
        # Return a one-element list; GraphState's operator.add reducer appends
        # it to the existing history rather than replacing it.
        return {"messages": [response]}

    workflow = StateGraph(GraphState)
    workflow.add_node("llm_node", call_model)
    workflow.set_entry_point("llm_node")
    workflow.add_edge("llm_node", END)

    # checkpointer=None is equivalent to a plain compile() (no persistence).
    return workflow.compile(checkpointer=checkpointer)
|
|
|
|
|
|