import logging
from typing import List, Dict, Any, Union, AsyncGenerator

# LangChain imports
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage

# Local imports
from .utils import getconfig, get_auth
from .prompts import system_prompt
from .sources import (
    _process_context,
    _build_messages,
    _parse_citations,
    _extract_sources,
    _create_sources_list,
    clean_citations
)

# Set up logger
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------
# Configuration and Model Initialization
# ---------------------------------------------------------------------
config = getconfig("params.cfg")

PROVIDER = config.get("generator", "PROVIDER")
MODEL = config.get("generator", "MODEL")
MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER")
ORGANIZATION = config.get("generator", "ORGANIZATION")

# Set up authentication for the selected provider
auth_config = get_auth(PROVIDER)


def _get_chat_model():
    """Initialize the appropriate LangChain chat model based on the configured provider."""
    common_params = {"temperature": TEMPERATURE, "max_tokens": MAX_TOKENS}

    providers = {
        "openai": lambda: ChatOpenAI(
            model=MODEL,
            openai_api_key=auth_config["api_key"],
            streaming=True,
            **common_params
        ),
        "anthropic": lambda: ChatAnthropic(
            model=MODEL,
            anthropic_api_key=auth_config["api_key"],
            streaming=True,
            **common_params
        ),
        "cohere": lambda: ChatCohere(
            model=MODEL,
            cohere_api_key=auth_config["api_key"],
            streaming=True,
            **common_params
        ),
        "huggingface": lambda: ChatHuggingFace(llm=HuggingFaceEndpoint(
            repo_id=MODEL,
            huggingfacehub_api_token=auth_config["api_key"],
            task="text-generation",
            provider=INFERENCE_PROVIDER,
            server_kwargs={"bill_to": ORGANIZATION},
            temperature=TEMPERATURE,
            max_new_tokens=MAX_TOKENS,
            streaming=True
        ))
    }

    if PROVIDER not in providers:
        raise ValueError(f"Unsupported provider: {PROVIDER}")

    return providers[PROVIDER]()


# Initialize chat model
chat_model = _get_chat_model()


# ---------------------------------------------------------------------
# LLM Call Functions
# ---------------------------------------------------------------------
async def _call_llm(messages: list) -> str:
    """Provider-agnostic LLM call using LangChain (non-streaming)."""
    try:
        response = await chat_model.ainvoke(messages)
        return response.content.strip()
    except Exception as e:
        logger.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
        raise


async def _call_llm_streaming(messages: list) -> AsyncGenerator[str, None]:
    """Provider-agnostic streaming LLM call using LangChain."""
    try:
        async for chunk in chat_model.astream(messages):
            if hasattr(chunk, 'content') and chunk.content:
                yield chunk.content
    except Exception as e:
        logger.exception(f"LLM streaming failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
        yield f"Error: {str(e)}"


# ---------------------------------------------------------------------
# Main Generation Functions
# ---------------------------------------------------------------------
async def generate(query: str,
                   context: Union[str, List[Dict[str, Any]]],
                   chatui_format: bool = False) -> Union[str, Dict[str, Any]]:
    """Generate an answer to a query using provided context through RAG."""
    if not query.strip():
        error_msg = "Query cannot be empty"
cannot be empty" return {"error": error_msg} if chatui_format else f"Error: {error_msg}" try: formatted_context, processed_results = _process_context(context) messages = _build_messages(query, formatted_context) answer = await _call_llm(messages) # Clean citations to ensure proper format and remove unwanted sections answer = clean_citations(answer) if chatui_format: result = {"answer": answer} if processed_results: cited_numbers = _parse_citations(answer) cited_sources = _extract_sources(processed_results, cited_numbers) result["sources"] = _create_sources_list(cited_sources) return result else: return answer except Exception as e: logger.exception("Generation failed") error_msg = str(e) return {"error": error_msg} if chatui_format else f"Error: {error_msg}" async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]]], chatui_format: bool = False) -> AsyncGenerator[Union[str, Dict[str, Any]], None]: """Generate a streaming answer to a query using provided context through RAG""" if not query.strip(): error_msg = "Query cannot be empty" if chatui_format: yield {"event": "error", "data": {"error": error_msg}} else: yield f"Error: {error_msg}" return try: formatted_context, processed_results = _process_context(context) messages = _build_messages(system_prompt, query, formatted_context) # Stream the response and accumulate for citation parsing (filter out any sources that were not cited) accumulated_response = "" async for chunk in _call_llm_streaming(messages): accumulated_response += chunk if chatui_format: yield {"event": "data", "data": chunk} else: yield chunk # Clean citations in the complete response cleaned_response = clean_citations(accumulated_response) # Send sources at the end if available and in ChatUI format if chatui_format and processed_results: cited_numbers = _parse_citations(cleaned_response) cited_sources = _extract_sources(processed_results, cited_numbers) sources = _create_sources_list(cited_sources) yield {"event": "sources", "data": {"sources": sources}} # Send END event for ChatUI format if chatui_format: yield {"event": "end", "data": {}} except Exception as e: logger.exception("Streaming generation failed") error_msg = str(e) if chatui_format: yield {"event": "error", "data": {"error": error_msg}} else: yield f"Error: {error_msg}"