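"""LangGraph workflow for the planner agent.

Builds a StateGraph that generates a plan from the user query, executes each
plan step (Cypher generation, retrieval, evaluation, processing), and finally
produces structured Key Issues from the gathered context.
"""
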
import logging
import re
from typing import List, Dict, Any

from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver  # Or SqliteSaver etc.
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser

from .config import settings
from .schemas import PlannerState, KeyIssue, GraphConfig  # Import schemas
from .prompts import get_initial_planner_prompt, KEY_ISSUE_STRUCTURING_PROMPT
from .llm_interface import get_llm
from .graph_operations import (
    generate_cypher_auto, generate_cypher_guided,
    retrieve_documents, evaluate_documents
)
from .processing import process_documents

logger = logging.getLogger(__name__)


# --- Graph Nodes ---

def start_planning(state: PlannerState) -> Dict[str, Any]:
    """Generates the initial plan based on the user query."""
    logger.info("Node: start_planning")
    user_query = state['user_query']
    if not user_query:
        return {"error": "User query is empty."}

    initial_prompt = get_initial_planner_prompt(settings.plan_method, user_query)
    llm = get_llm(settings.main_llm_model)
    chain = initial_prompt | llm | StrOutputParser()

    try:
        plan_text = chain.invoke({})  # Prompt already includes query
        logger.debug(f"Raw plan text: {plan_text}")

        # Extract plan steps (simple regex, might need refinement)
        plan_match = re.search(r"Plan:(.*?)<END_OF_PLAN>", plan_text, re.DOTALL | re.IGNORECASE)
        if plan_match:
            plan_steps = [step.strip() for step in re.split(r"\n\s*\d+\.\s*", plan_match.group(1)) if step.strip()]
            logger.info(f"Extracted plan: {plan_steps}")
            return {
                "plan": plan_steps,
                "current_plan_step_index": 0,
                "messages": [AIMessage(content=plan_text)],
                "step_outputs": {}  # Initialize step outputs
            }
        else:
            logger.error("Could not parse plan from LLM response.")
            return {"error": "Failed to parse plan from LLM response.", "messages": [AIMessage(content=plan_text)]}
    except Exception as e:
        logger.error(f"Error during plan generation: {e}", exc_info=True)
        return {"error": f"LLM error during plan generation: {e}"}


def execute_plan_step(state: PlannerState) -> Dict[str, Any]:
    """Executes the current step of the plan (retrieval, processing)."""
    current_index = state['current_plan_step_index']
    plan = state['plan']
    user_query = state['user_query']  # Use original query for context

    if current_index >= len(plan):
        logger.warning("Plan step index out of bounds, attempting to finalize.")
        # This should ideally be handled by the conditional edge, but as a fallback
        return {"error": "Plan execution finished unexpectedly."}

    step_description = plan[current_index]
    logger.info(f"Node: execute_plan_step - Step {current_index + 1}/{len(plan)}: {step_description}")

    # --- Determine Query for Retrieval ---
    # Use the step description combined with the original query for context.
    query_for_retrieval = f"Regarding the query '{user_query}', focus on: {step_description}"
    logger.info(f"Query for retrieval: {query_for_retrieval}")

    # --- Generate Cypher ---
    cypher_query = ""
    if settings.cypher_gen_method == 'auto':
        cypher_query = generate_cypher_auto(query_for_retrieval)
    elif settings.cypher_gen_method == 'guided':
        cypher_query = generate_cypher_guided(query_for_retrieval, current_index)
    # TODO: Add cypher validation if settings.validate_cypher is True

    # --- Retrieve Documents ---
    retrieved_docs = retrieve_documents(cypher_query)

    # --- Evaluate Documents ---
    evaluated_docs = evaluate_documents(retrieved_docs, query_for_retrieval)

    # --- Process Documents ---
    # Using configured processing steps
    processed_docs_content = process_documents(evaluated_docs, settings.process_steps)

    # --- Store Step Output ---
    # Store the processed content relevant to this step
    step_output = "\n\n".join(processed_docs_content) if processed_docs_content else "No relevant information found for this step."
    # Copy the dict so the checkpointed state is not mutated in place
    current_step_outputs = dict(state.get('step_outputs', {}))
    current_step_outputs[current_index] = step_output
    logger.info(f"Finished executing plan step {current_index + 1}. Stored output.")

    return {
        "current_plan_step_index": current_index + 1,
        "messages": [SystemMessage(content=f"Completed plan step {current_index + 1}. Context gathered:\n{step_output[:500]}...")],  # Add summary message
        "step_outputs": current_step_outputs
    }


def generate_structured_issues(state: PlannerState) -> Dict[str, Any]:
    """Generates the final structured Key Issues based on all gathered context."""
    logger.info("Node: generate_structured_issues")
    user_query = state['user_query']
    step_outputs = state.get('step_outputs', {})

    # --- Combine Context from All Steps ---
    full_context = f"Original User Query: {user_query}\n\n"
    full_context += "Context gathered during planning:\n"
    for i, output in sorted(step_outputs.items()):
        full_context += f"--- Context from Step {i+1} ---\n{output}\n\n"
    if not step_outputs:
        full_context += "No context was gathered during the planning steps.\n"

    logger.info(f"Generating key issues using combined context (length: {len(full_context)} chars).")
    # logger.debug(f"Full Context for Key Issue Generation:\n{full_context}")  # Optional: log full context

    # --- Call LLM for Structured Output ---
    issue_llm = get_llm(settings.main_llm_model)
    # JsonOutputParser parses the model's JSON response; items are validated into KeyIssue below
    output_parser = JsonOutputParser(pydantic_object=List[KeyIssue])

    prompt = KEY_ISSUE_STRUCTURING_PROMPT.partial(
        # schema=output_parser.get_format_instructions(),  # Inject schema instructions if needed by prompt
    )
    chain = prompt | issue_llm | output_parser

    try:
        raw_issues = chain.invoke({
            "user_query": user_query,
            "context": full_context
        })
        # JsonOutputParser yields plain dicts, so coerce each item into a KeyIssue model
        structured_issues = [KeyIssue(**issue) if isinstance(issue, dict) else issue for issue in raw_issues]
        # Ensure IDs are sequential if the LLM didn't assign them correctly
        for i, issue in enumerate(structured_issues):
            issue.id = i + 1

        logger.info(f"Successfully generated {len(structured_issues)} structured key issues.")
        final_message = f"Generated {len(structured_issues)} Key Issues based on the query '{user_query}'."
        return {
            "key_issues": structured_issues,
            "messages": [AIMessage(content=final_message)],  # Final summary message
            "error": None  # Clear any previous errors
        }
    except Exception as e:
        logger.error(f"Failed to generate or parse structured key issues: {e}", exc_info=True)
        # Attempt to get raw output for debugging if possible
        raw_output = "Could not retrieve raw output."
        try:
            raw_chain = prompt | issue_llm | StrOutputParser()
            raw_output = raw_chain.invoke({"user_query": user_query, "context": full_context})
            logger.debug(f"Raw output from failed JSON parsing:\n{raw_output}")
        except Exception as raw_e:
            logger.error(f"Could not even get raw output: {raw_e}")
        return {"error": f"Failed to generate structured key issues: {e}. Raw output hint: {raw_output[:500]}..."}


# --- Conditional Edges ---

def should_continue_planning(state: PlannerState) -> str:
    """Determines if there are more plan steps to execute."""
    logger.debug("Edge: should_continue_planning")
    if state.get("error"):
        logger.error(f"Error state detected: {state['error']}. Ending execution.")
        return "error_state"  # Go to a potential error handling end node

    current_index = state['current_plan_step_index']
    plan_length = len(state.get('plan', []))
    if current_index < plan_length:
        logger.debug(f"Continuing plan execution. Next step index: {current_index}")
        return "continue_execution"
    else:
        logger.debug("Plan finished. Proceeding to final generation.")
        return "finalize"


# --- Build Graph ---

def build_graph():
    """Builds the LangGraph workflow."""
    workflow = StateGraph(PlannerState)

    # Add nodes
    workflow.add_node("start_planning", start_planning)
    workflow.add_node("execute_plan_step", execute_plan_step)
    workflow.add_node("generate_issues", generate_structured_issues)
    # Optional: Add an error handling node
    workflow.add_node("error_node", lambda state: {"messages": [SystemMessage(content=f"Execution failed: {state.get('error', 'Unknown error')}")]})

    # Define edges
    workflow.set_entry_point("start_planning")
    workflow.add_edge("start_planning", "execute_plan_step")  # Assume plan is always generated
    workflow.add_conditional_edges(
        "execute_plan_step",
        should_continue_planning,
        {
            "continue_execution": "execute_plan_step",  # Loop back to execute next step
            "finalize": "generate_issues",              # Move to final generation
            "error_state": "error_node"                 # Go to error node
        }
    )
    workflow.add_edge("generate_issues", END)
    workflow.add_edge("error_node", END)  # End after error

    # Compile the graph with memory (optional)
    # memory = MemorySaver()  # Use if state needs persistence between runs
    # app_graph = workflow.compile(checkpointer=memory)
    app_graph = workflow.compile()
    return app_graph
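

# Minimal usage sketch (assumptions: PlannerState accepts the keys used by the
# nodes above, and the config / LLM / Neo4j access used by the imported helpers
# is already set up). Not part of the original workflow; it only illustrates
# how the compiled graph might be invoked.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    graph = build_graph()
    # Only the user query is provided up front; the nodes fill in the rest of the state.
    initial_state = {
        "user_query": "What are the key issues around data retention in our platform?",  # hypothetical query
        "messages": [],
    }
    final_state = graph.invoke(initial_state)
    for issue in final_state.get("key_issues", []):
        print(issue)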