import logging
from typing import List, Dict, Any, Union, AsyncGenerator

# LangChain imports
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage

# Local imports
from .utils import getconfig, get_auth
from .prompts import system_prompt
from .sources import (
    _process_context,
    _build_messages,
    _parse_citations,
    _extract_sources,
    _create_sources_list,
    clean_citations
)

# Set up logger
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------
# Configuration and Model Initialization
# ---------------------------------------------------------------------
config = getconfig("params.cfg")

PROVIDER = config.get("generator", "PROVIDER")
MODEL = config.get("generator", "MODEL")
MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER")
ORGANIZATION = config.get("generator", "ORGANIZATION")

# Set up authentication for the selected provider
auth_config = get_auth(PROVIDER)
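
# For reference, a minimal sketch of the [generator] section that the
# config.get() calls above expect in params.cfg. The layout is assumed from
# those calls; the values shown here are placeholders, not the deployed ones:
#
#   [generator]
#   PROVIDER = huggingface
#   MODEL = meta-llama/Meta-Llama-3-8B-Instruct
#   MAX_TOKENS = 768
#   TEMPERATURE = 0.2
#   INFERENCE_PROVIDER = hf-inference
#   ORGANIZATION = my-org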

def _get_chat_model():
    """Initialize the appropriate LangChain chat model based on provider"""
    common_params = {"temperature": TEMPERATURE, "max_tokens": MAX_TOKENS}

    providers = {
        "openai": lambda: ChatOpenAI(model=MODEL, openai_api_key=auth_config["api_key"], streaming=True, **common_params),
        "anthropic": lambda: ChatAnthropic(model=MODEL, anthropic_api_key=auth_config["api_key"], streaming=True, **common_params),
        "cohere": lambda: ChatCohere(model=MODEL, cohere_api_key=auth_config["api_key"], streaming=True, **common_params),
        "huggingface": lambda: ChatHuggingFace(llm=HuggingFaceEndpoint(
            repo_id=MODEL,
            huggingfacehub_api_token=auth_config["api_key"],
            task="text-generation",
            provider=INFERENCE_PROVIDER,
            server_kwargs={"bill_to": ORGANIZATION},
            temperature=TEMPERATURE,
            max_new_tokens=MAX_TOKENS,
            streaming=True
        ))
    }

    if PROVIDER not in providers:
        raise ValueError(f"Unsupported provider: {PROVIDER}")

    return providers[PROVIDER]()


# Initialize chat model
chat_model = _get_chat_model()

# ---------------------------------------------------------------------
# LLM Call Functions
# ---------------------------------------------------------------------
async def _call_llm(messages: list) -> str:
    """Provider-agnostic LLM call using LangChain (non-streaming)"""
    try:
        response = await chat_model.ainvoke(messages)
        return response.content.strip()
    except Exception as e:
        logger.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
        raise
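
# For illustration only: the message lists passed to _call_llm and
# _call_llm_streaming are assumed to be LangChain message objects as produced
# by _build_messages in .sources. A hypothetical example (content is made up):
#
#   messages = [
#       SystemMessage(content=system_prompt),
#       HumanMessage(content=f"Context:\n{formatted_context}\n\nQuestion: {query}"),
#   ]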

async def _call_llm_streaming(messages: list) -> AsyncGenerator[str, None]:
    """Provider-agnostic streaming LLM call using LangChain"""
    try:
        async for chunk in chat_model.astream(messages):
            if hasattr(chunk, 'content') and chunk.content:
                yield chunk.content
    except Exception as e:
        logger.exception(f"LLM streaming failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
        yield f"Error: {str(e)}"

# ---------------------------------------------------------------------
# Main Generation Functions
# ---------------------------------------------------------------------
async def generate(query: str, context: Union[str, List[Dict[str, Any]]], chatui_format: bool = False) -> Union[str, Dict[str, Any]]:
    """Generate an answer to a query using provided context through RAG"""
    if not query.strip():
        error_msg = "Query cannot be empty"
        return {"error": error_msg} if chatui_format else f"Error: {error_msg}"

    try:
        formatted_context, processed_results = _process_context(context)
        messages = _build_messages(system_prompt, query, formatted_context)
        answer = await _call_llm(messages)

        # Clean citations to ensure proper format and remove unwanted sections
        answer = clean_citations(answer)

        if chatui_format:
            result = {"answer": answer}
            if processed_results:
                cited_numbers = _parse_citations(answer)
                cited_sources = _extract_sources(processed_results, cited_numbers)
                result["sources"] = _create_sources_list(cited_sources)
            return result
        else:
            return answer
    except Exception as e:
        logger.exception("Generation failed")
        error_msg = str(e)
        return {"error": error_msg} if chatui_format else f"Error: {error_msg}"

async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]]], chatui_format: bool = False) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
    """Generate a streaming answer to a query using provided context through RAG"""
    if not query.strip():
        error_msg = "Query cannot be empty"
        if chatui_format:
            yield {"event": "error", "data": {"error": error_msg}}
        else:
            yield f"Error: {error_msg}"
        return

    try:
        formatted_context, processed_results = _process_context(context)
        messages = _build_messages(system_prompt, query, formatted_context)

        # Stream the response and accumulate it for citation parsing
        # (so sources that were never cited can be filtered out)
        accumulated_response = ""
        async for chunk in _call_llm_streaming(messages):
            accumulated_response += chunk
            if chatui_format:
                yield {"event": "data", "data": chunk}
            else:
                yield chunk

        # Clean citations in the complete response
        cleaned_response = clean_citations(accumulated_response)

        # Send sources at the end if available and in ChatUI format
        if chatui_format and processed_results:
            cited_numbers = _parse_citations(cleaned_response)
            cited_sources = _extract_sources(processed_results, cited_numbers)
            sources = _create_sources_list(cited_sources)
            yield {"event": "sources", "data": {"sources": sources}}

        # Send END event for ChatUI format
        if chatui_format:
            yield {"event": "end", "data": {}}
    except Exception as e:
        logger.exception("Streaming generation failed")
        error_msg = str(e)
        if chatui_format:
            yield {"event": "error", "data": {"error": error_msg}}
        else:
            yield f"Error: {error_msg}"
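

# ---------------------------------------------------------------------
# Example usage (hypothetical smoke test; the query and context strings
# below are placeholders and not part of the module's runtime behavior)
# ---------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo():
        # Context may be a plain string or a list of result dicts; a plain
        # string is used here to keep the example self-contained.
        context = "Example retrieved passage about renewable energy targets."
        print(await generate("What are the targets?", context))
        async for chunk in generate_streaming("What are the targets?", context):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())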