# --- START OF FILE graph_logic.py ---

import os
from typing import TypedDict, Annotated, List
import operator
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver # Optional for checkpointing

# Load environment variables. Optional here, since app_main.py already loads
# them, but convenient when testing this module on its own.
# load_dotenv()

# --- LangGraph State Definition ---

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        messages: The list of messages comprising the conversation.
                  operator.add indicates messages should be appended.
    """
    messages: Annotated[List[BaseMessage], operator.add]
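
# Reducer behavior (a sketch): because of operator.add, a node that returns
# {"messages": [new_msg]} appends to the existing history rather than
# replacing it. Conceptually:
#   merged = previous_state["messages"] + node_update["messages"]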

# --- LLM Initialization ---

def initialize_llm(provider: str, model_name: str, temperature: float, api_key: str):
    """Initializes the appropriate LangChain Chat Model."""
    if provider == "OpenAI":
        if not api_key:
            raise ValueError("OpenAI API key is missing. Please set OPENAI_API_KEY.")
        return ChatOpenAI(api_key=api_key, model_name=model_name, temperature=temperature)
    elif provider == "Anthropic":
        if not api_key:
            raise ValueError("Anthropic API key is missing. Please set ANTHROPIC_API_KEY.")
        return ChatAnthropic(api_key=api_key, model_name=model_name, temperature=temperature)
    else:
        raise ValueError(f"Unsupported LLM provider: {provider}")

# --- LangGraph Node and Graph Building ---

def create_chat_graph(llm):
    """
    Builds and compiles the LangGraph conversational graph.

    Args:
        llm: An initialized LangChain Chat Model instance.

    Returns:
        A compiled LangGraph application.
    """

    # Define the function that calls the LLM - it closes over the 'llm' variable
    def call_model(state: GraphState) -> dict:
        """Invokes the provided LLM with the current conversation state."""
        messages = state['messages']
        response = llm.invoke(messages)
        # Return the new message in a list; operator.add appends it to state["messages"]
        return {"messages": [response]}

    # Build the graph workflow
    workflow = StateGraph(GraphState)

    # Add the single node that runs the LLM
    workflow.add_node("llm_node", call_model)

    # Set the entry point and the only edge
    workflow.set_entry_point("llm_node")
    workflow.add_edge("llm_node", END) # Conversation ends after one LLM call per turn

    # Compile the graph
    # Optional: Add memory for checkpointing if needed
    # memory = MemorySaver()
    # graph = workflow.compile(checkpointer=memory)
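    # (With a checkpointer enabled, each invoke() call must also pass
    # config={"configurable": {"thread_id": <conversation id>}}.)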
    graph = workflow.compile()

    return graph
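
# Minimal smoke test (a sketch, not part of the app: assumes OPENAI_API_KEY is
# set and that the hard-coded model name is valid for your account):
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    load_dotenv()  # pick up OPENAI_API_KEY from a local .env, if present
    demo_llm = initialize_llm(
        provider="OpenAI",
        model_name="gpt-4o-mini",  # assumed model name; substitute your own
        temperature=0.7,
        api_key=os.getenv("OPENAI_API_KEY", ""),
    )
    demo_graph = create_chat_graph(demo_llm)
    final_state = demo_graph.invoke({"messages": [HumanMessage(content="Hello!")]})
    print(final_state["messages"][-1].content)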

# --- END OF FILE graph_logic.py ---