# File size: 2,866 Bytes
# 74c02ff
# --- START OF FILE graph_logic.py ---
import os
from typing import TypedDict, Annotated, List
import operator
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain.schema import BaseMessage, AIMessage
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver # Optional for checkpointing
# Loading environment variables here is optional; it is only needed when this
# module is tested independently of app_main.py (which is expected to load them).
# load_dotenv()  # Uncomment to load the .env file from this module
# --- LangGraph State Definition ---
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        messages: The list of messages comprising the conversation.
            The ``operator.add`` reducer attached through ``Annotated``
            indicates that message lists returned by a node should be
            appended to the existing list rather than replace it.
    """

    messages: Annotated[List[BaseMessage], operator.add]
# --- LLM Initialization ---
def initialize_llm(provider: str, model_name: str, temperature: float, api_key: str):
    """Build the LangChain chat model for the requested provider.

    Args:
        provider: Either "OpenAI" or "Anthropic" (matched exactly).
        model_name: Model identifier forwarded to the chat model constructor.
        temperature: Sampling temperature forwarded to the chat model constructor.
        api_key: API key for the chosen provider.

    Returns:
        A ``ChatOpenAI`` or ``ChatAnthropic`` instance.

    Raises:
        ValueError: If the provider is not supported, or if the API key for a
            supported provider is empty or missing.
    """
    wants_openai = provider == "OpenAI"
    wants_anthropic = provider == "Anthropic"

    # Reject unknown providers first; the key is only checked for known ones.
    if not (wants_openai or wants_anthropic):
        raise ValueError(f"Unsupported LLM provider: {provider}")

    if not api_key:
        if wants_openai:
            raise ValueError("OpenAI API key is missing. Please set OPENAI_API_KEY.")
        raise ValueError("Anthropic API key is missing. Please set ANTHROPIC_API_KEY.")

    # Both chat model classes are constructed with the same keyword arguments.
    model_kwargs = {
        "api_key": api_key,
        "model_name": model_name,
        "temperature": temperature,
    }
    chat_model_cls = ChatOpenAI if wants_openai else ChatAnthropic
    return chat_model_cls(**model_kwargs)
# --- LangGraph Node and Graph Building ---
def create_chat_graph(llm):
    """
    Assemble and compile a single-node conversational LangGraph.

    Args:
        llm: An initialized LangChain Chat Model instance; the node calls
            its ``invoke`` method with the conversation messages.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    node_name = "llm_node"

    # The node function is a closure so that it can reach the 'llm' argument.
    def call_model(state: GraphState) -> dict:
        """Send the conversation so far to the LLM and return its reply."""
        reply = llm.invoke(state['messages'])
        # Wrapped in a list so it is added onto the state's message list.
        return {"messages": [reply]}

    # One node only: entry point -> LLM -> END, i.e. one LLM call per turn.
    builder = StateGraph(GraphState)
    builder.add_node(node_name, call_model)
    builder.set_entry_point(node_name)
    builder.add_edge(node_name, END)

    # Optional: for checkpointing, compile with a checkpointer instead, e.g.
    #   builder.compile(checkpointer=MemorySaver())
    return builder.compile()
# --- END OF FILE graph_logic.py ---