import logging
import os
import uuid  # for generating thread IDs for the checkpointer
from typing import AsyncIterator, Optional, TypedDict

from dotenv import find_dotenv, load_dotenv
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from smolagents import CodeAgent, LiteLLMModel
from smolagents.memory import ActionStep, FinalAnswerStep
from smolagents.monitoring import LogLevel
# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv(find_dotenv())

# Get required environment variables with validation
API_BASE = os.getenv("API_BASE")
API_KEY = os.getenv("API_KEY")
MODEL_ID = os.getenv("MODEL_ID")

if not all([API_BASE, API_KEY, MODEL_ID]):
    raise ValueError(
        "Missing required environment variables: API_BASE, API_KEY, MODEL_ID"
    )
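
# Example .env contents (placeholder values, not from this project; MODEL_ID can be
# any model identifier that LiteLLM understands):
#   API_BASE=https://your-openai-compatible-endpoint/v1
#   API_KEY=sk-...
#   MODEL_ID=openai/gpt-4o-mini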
# Define the state types for our graph
class AgentState(TypedDict):
    task: str
    current_step: Optional[dict]  # Store serializable dict instead of ActionStep
    error: Optional[str]
    answer_text: Optional[str]
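
# Once populated, current_step mirrors the serialization done in process_step:
#   {"step_number", "model_output", "observations", "tool_calls", "action_output"}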
# Initialize model with error handling
try:
    model = LiteLLMModel(
        api_base=API_BASE,
        api_key=API_KEY,
        model_id=MODEL_ID,
    )
except Exception as e:
    logger.error(f"Failed to initialize model: {str(e)}")
    raise

# Initialize agent with error handling
try:
    agent = CodeAgent(
        add_base_tools=True,
        additional_authorized_imports=["pandas", "numpy"],
        max_steps=10,
        model=model,
        tools=[],
        step_callbacks=None,
        verbosity_level=LogLevel.ERROR,
    )
    agent.logger.console.width = 66
except Exception as e:
    logger.error(f"Failed to initialize agent: {str(e)}")
    raise
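
# Note: add_base_tools=True adds smolagents' default toolbox (e.g. a web search tool)
# on top of the empty tools list; the exact set depends on the installed smolagents
# version.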
async def process_step(state: AgentState) -> AgentState:
    """Process a single step of the agent's execution."""
    try:
        # Clear previous step results before running agent.run
        state["current_step"] = None
        state["answer_text"] = None
        state["error"] = None

        steps = agent.run(
            task=state["task"],
            additional_args=None,
            images=None,
            max_steps=1,  # Process one step at a time
            stream=True,
            reset=False,  # Maintain agent's internal state across process_step calls
        )

        for step in steps:
            if isinstance(step, ActionStep):
                # Convert ActionStep to a serializable dict using its attributes
                state["current_step"] = {
                    "step_number": step.step_number,
                    "model_output": step.model_output,
                    "observations": step.observations,
                    "tool_calls": [
                        {"name": tc.name, "arguments": tc.arguments}
                        for tc in (step.tool_calls or [])
                    ],
                    "action_output": step.action_output,
                }
                logger.info(f"Processed action step {step.step_number}")
            elif isinstance(step, FinalAnswerStep):
                state["answer_text"] = step.final_answer
                logger.info("Processed final answer")
                logger.debug(f"Final answer details: {step}")
                logger.info(f"Extracted answer text: {state['answer_text']}")
                # Return immediately when we get a final answer
                return state

        # If the loop finishes without a FinalAnswerStep, return the current state
        return state
    except Exception as e:
        state["error"] = str(e)
        logger.error(f"Error during agent execution step: {str(e)}")
        return state
def should_continue(state: AgentState) -> bool:
    """Determine if the agent should continue processing steps."""
    # Continue only if we have neither an answer_text nor an error
    continue_execution = state.get("answer_text") is None and state.get("error") is None
    logger.debug(
        f"Checking should_continue: has_answer={state.get('answer_text') is not None}, "
        f"has_error={state.get('error') is not None} -> Continue={continue_execution}"
    )
    return continue_execution
# Build the LangGraph graph once, with persistence
memory = MemorySaver()
builder = StateGraph(AgentState)
builder.add_node("process_step", process_step)
builder.add_edge(START, "process_step")
builder.add_conditional_edges(
    "process_step", should_continue, {True: "process_step", False: END}
)
graph = builder.compile(checkpointer=memory)
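
def get_persisted_state(thread_id: str) -> dict:
    """Illustrative helper (an addition, not part of the original app): read back the
    state last saved by the MemorySaver checkpointer for a given thread_id, using
    LangGraph's graph.get_state() on the compiled graph."""
    snapshot = graph.get_state({"configurable": {"thread_id": thread_id}})
    return snapshot.values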
async def stream_execution(task: str, thread_id: str) -> AsyncIterator[dict]:
    """Stream the execution of the agent."""
    if not task:
        raise ValueError("Task cannot be empty")

    logger.info(f"Initializing agent execution for task: {task}")

    # Initialize the state
    initial_state: AgentState = {
        "task": task,
        "current_step": None,
        "error": None,
        "answer_text": None,
    }

    # Pass thread_id via the config dict so the checkpointer can persist state
    async for state in graph.astream(
        initial_state, {"configurable": {"thread_id": thread_id}}
    ):
        yield state
        # astream yields {node_name: state_update}; check the inner state and
        # propagate an error immediately if it occurs without an answer
        node_state = next(iter(state.values()), {}) or {}
        if node_state.get("error") and not node_state.get("answer_text"):
            logger.error(f"Propagating error from stream: {node_state['error']}")
            raise Exception(node_state["error"])
async def run_with_streaming(task: str, thread_id: str) -> dict:
    """Run the agent with streaming output and return the results."""
    last_state = None
    steps = []
    error = None
    final_answer_text = None
    try:
        logger.info(f"Starting execution run for task: {task}")
        async for state in stream_execution(task, thread_id):
            last_state = state
            # Unwrap the {node_name: state_update} dict before reading fields
            node_state = next(iter(state.values()), {}) or {}
            if current_step := node_state.get("current_step"):
                if not steps or steps[-1]["step_number"] != current_step["step_number"]:
                    steps.append(current_step)
                    # Keep print here for direct user feedback during streaming
                    print(f"\nStep {current_step['step_number']}:")
                    print(f"Model Output: {current_step['model_output']}")
                    print(f"Observations: {current_step['observations']}")
                    if current_step.get("tool_calls"):
                        print("Tool Calls:")
                        for tc in current_step["tool_calls"]:
                            print(f"  - {tc['name']}: {tc['arguments']}")
                    if current_step.get("action_output"):
                        print(f"Action Output: {current_step['action_output']}")

        # After the stream is finished, process the last state
        logger.info("Stream finished.")
        if last_state:
            # LangGraph streams dicts where keys are node names, values are state dicts
            node_name = list(last_state.keys())[0]
            actual_state = last_state.get(node_name)
            if actual_state:
                final_answer_text = actual_state.get("answer_text")
                error = actual_state.get("error")
                logger.info(
                    f"Final answer text extracted from last state: {final_answer_text}"
                )
                logger.info(f"Error extracted from last state: {error}")
                # Ensure the steps list is consistent with the final state if needed
                last_step_in_state = actual_state.get("current_step")
                if last_step_in_state and (
                    not steps
                    or steps[-1]["step_number"] != last_step_in_state["step_number"]
                ):
                    logger.debug("Adding last step from final state to steps list.")
                    steps.append(last_step_in_state)
            else:
                logger.warning(
                    "Could not find actual state dictionary within last_state."
                )
        return {"steps": steps, "final_answer": final_answer_text, "error": error}
    except Exception as e:
        import traceback

        logger.error(
            f"Exception during run_with_streaming: {str(e)}\n{traceback.format_exc()}"
        )
        # Attempt to return based on the last known state even if the exception
        # occurred outside the stream
        final_answer_text = None
        error_msg = str(e)
        if last_state:
            node_name = list(last_state.keys())[0]
            actual_state = last_state.get(node_name)
            if actual_state:
                final_answer_text = actual_state.get("answer_text")
        return {"steps": steps, "final_answer": final_answer_text, "error": error_msg}
if __name__ == "__main__":
    import asyncio

    # Example usage (uuid is already imported at module level)
    task_to_run = "What is the capital of France?"
    thread_id = str(uuid.uuid4())  # Generate a unique thread ID for this run

    logger.info(
        f"Starting agent run from __main__ for task: '{task_to_run}' with thread_id: {thread_id}"
    )
    result = asyncio.run(run_with_streaming(task_to_run, thread_id))
    logger.info("Agent run finished.")

    # Print final results
    print("\n--- Execution Results ---")
    print(f"Number of Steps: {len(result.get('steps', []))}")
    # Optionally print step details
    # for i, step in enumerate(result.get('steps', [])):
    #     print(f"Step {i+1} Details: {step}")
    print(f"Final Answer: {result.get('final_answer') or 'Not found'}")
    if err := result.get("error"):
        print(f"Error: {err}")