import logging
from typing import List, Dict, Any, Union, AsyncGenerator
# LangChain imports
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage
# Local imports
from .utils import getconfig, get_auth
from .prompts import system_prompt
from .sources import (
    _process_context,
    _build_messages,
    _parse_citations,
    _extract_sources,
    _create_sources_list,
    clean_citations,
)
# Set up logger
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------
# Configuration and Model Initialization
# ---------------------------------------------------------------------
config = getconfig("params.cfg")
PROVIDER = config.get("generator", "PROVIDER")
MODEL = config.get("generator", "MODEL")
MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER", fallback=None)  # only used by the huggingface provider
ORGANIZATION = config.get("generator", "ORGANIZATION", fallback=None)  # optional billing org for HF inference providers
# Set up authentication for the selected provider
auth_config = get_auth(PROVIDER)
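# For reference, a minimal [generator] section of params.cfg might look like the
# sketch below. The values are illustrative assumptions, not taken from this repo;
# "novita" reflects the HF inference provider this change adds support for.
#
#   [generator]
#   PROVIDER = huggingface
#   MODEL = meta-llama/Llama-3.1-8B-Instruct
#   MAX_TOKENS = 1024
#   TEMPERATURE = 0.2
#   INFERENCE_PROVIDER = novita
#   ORGANIZATION = my-org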
def _get_chat_model():
    """Initialize the appropriate LangChain chat model based on provider."""
    common_params = {"temperature": TEMPERATURE, "max_tokens": MAX_TOKENS}
    providers = {
        "openai": lambda: ChatOpenAI(
            model=MODEL, openai_api_key=auth_config["api_key"], streaming=True, **common_params
        ),
        "anthropic": lambda: ChatAnthropic(
            model=MODEL, anthropic_api_key=auth_config["api_key"], streaming=True, **common_params
        ),
        "cohere": lambda: ChatCohere(
            model=MODEL, cohere_api_key=auth_config["api_key"], streaming=True, **common_params
        ),
        "huggingface": lambda: ChatHuggingFace(llm=HuggingFaceEndpoint(
            repo_id=MODEL,
            huggingfacehub_api_token=auth_config["api_key"],
            task="text-generation",
            provider=INFERENCE_PROVIDER,
            # Only pass billing info when an organization is configured
            server_kwargs={"bill_to": ORGANIZATION} if ORGANIZATION else {},
            temperature=TEMPERATURE,
            max_new_tokens=MAX_TOKENS,
            streaming=True,
        )),
    }
    if PROVIDER not in providers:
        raise ValueError(f"Unsupported provider: {PROVIDER}")
    return providers[PROVIDER]()
# Initialize chat model
chat_model = _get_chat_model()
# ---------------------------------------------------------------------
# LLM Call Functions
# ---------------------------------------------------------------------
async def _call_llm(messages: list) -> str:
    """Provider-agnostic LLM call using LangChain (non-streaming)."""
    try:
        response = await chat_model.ainvoke(messages)
        return response.content.strip()
    except Exception as e:
        logger.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
        raise
async def _call_llm_streaming(messages: list) -> AsyncGenerator[str, None]:
    """Provider-agnostic streaming LLM call using LangChain."""
    try:
        async for chunk in chat_model.astream(messages):
            if hasattr(chunk, "content") and chunk.content:
                yield chunk.content
    except Exception as e:
        logger.exception(f"LLM streaming failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
        yield f"Error: {str(e)}"
# ---------------------------------------------------------------------
# Main Generation Functions
# ---------------------------------------------------------------------
async def generate(query: str, context: Union[str, List[Dict[str, Any]]], chatui_format: bool = False) -> Union[str, Dict[str, Any]]:
    """Generate an answer to a query using provided context through RAG."""
    if not query.strip():
        error_msg = "Query cannot be empty"
        return {"error": error_msg} if chatui_format else f"Error: {error_msg}"
    try:
        formatted_context, processed_results = _process_context(context)
        # Build messages the same way as the streaming path, including the system prompt
        messages = _build_messages(system_prompt, query, formatted_context)
        answer = await _call_llm(messages)
        # Clean citations to ensure proper format and remove unwanted sections
        answer = clean_citations(answer)
        if chatui_format:
            result = {"answer": answer}
            if processed_results:
                cited_numbers = _parse_citations(answer)
                cited_sources = _extract_sources(processed_results, cited_numbers)
                result["sources"] = _create_sources_list(cited_sources)
            return result
        else:
            return answer
    except Exception as e:
        logger.exception("Generation failed")
        error_msg = str(e)
        return {"error": error_msg} if chatui_format else f"Error: {error_msg}"
async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]]], chatui_format: bool = False) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
    """Generate a streaming answer to a query using provided context through RAG."""
    if not query.strip():
        error_msg = "Query cannot be empty"
        if chatui_format:
            yield {"event": "error", "data": {"error": error_msg}}
        else:
            yield f"Error: {error_msg}"
        return
    try:
        formatted_context, processed_results = _process_context(context)
        messages = _build_messages(system_prompt, query, formatted_context)
        # Stream the response, accumulating it so citations can be parsed once complete
        accumulated_response = ""
        async for chunk in _call_llm_streaming(messages):
            accumulated_response += chunk
            if chatui_format:
                yield {"event": "data", "data": chunk}
            else:
                yield chunk
        # Clean citations in the complete response
        cleaned_response = clean_citations(accumulated_response)
        # In ChatUI format, send only the sources that were actually cited
        if chatui_format and processed_results:
            cited_numbers = _parse_citations(cleaned_response)
            cited_sources = _extract_sources(processed_results, cited_numbers)
            sources = _create_sources_list(cited_sources)
            yield {"event": "sources", "data": {"sources": sources}}
        # Send END event for ChatUI format
        if chatui_format:
            yield {"event": "end", "data": {}}
    except Exception as e:
        logger.exception("Streaming generation failed")
        error_msg = str(e)
        if chatui_format:
            yield {"event": "error", "data": {"error": error_msg}}
        else:
            yield f"Error: {error_msg}"