import logging
from typing import List, Dict, Any, Union, AsyncGenerator
# LangChain imports
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage
# Local imports
from .utils import getconfig, get_auth
from .prompts import system_prompt
from .sources import (
_process_context,
_build_messages,
_parse_citations,
_extract_sources,
_create_sources_list,
clean_citations
)
# Set up logger
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------
# Configuration and Model Initialization
# ---------------------------------------------------------------------
config = getconfig("params.cfg")
PROVIDER = config.get("generator", "PROVIDER")
MODEL = config.get("generator", "MODEL")
MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
INFERENCE_PROVIDER = config.get("generator", "INFERENCE_PROVIDER")
ORGANIZATION = config.get("generator", "ORGANIZATION")
# Set up authentication for the selected provider
auth_config = get_auth(PROVIDER)
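# Example [generator] section of params.cfg (values are illustrative only,
# not taken from the actual config):
#
#   [generator]
#   PROVIDER = openai
#   MODEL = gpt-4o-mini
#   MAX_TOKENS = 1024
#   TEMPERATURE = 0.2
#   INFERENCE_PROVIDER = hf-inference
#   ORGANIZATION = my-org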
def _get_chat_model():
"""Initialize the appropriate LangChain chat model based on provider"""
    common_params = {"temperature": TEMPERATURE, "max_tokens": MAX_TOKENS}
    # Each entry is a lambda so that only the selected provider's client is instantiated
    providers = {
        "openai": lambda: ChatOpenAI(model=MODEL, openai_api_key=auth_config["api_key"], streaming=True, **common_params),
        "anthropic": lambda: ChatAnthropic(model=MODEL, anthropic_api_key=auth_config["api_key"], streaming=True, **common_params),
        "cohere": lambda: ChatCohere(model=MODEL, cohere_api_key=auth_config["api_key"], streaming=True, **common_params),
        # HuggingFaceEndpoint takes generation parameters directly (max_new_tokens
        # rather than max_tokens), so common_params is not unpacked here
        "huggingface": lambda: ChatHuggingFace(llm=HuggingFaceEndpoint(
repo_id=MODEL,
huggingfacehub_api_token=auth_config["api_key"],
task="text-generation",
provider=INFERENCE_PROVIDER,
server_kwargs={"bill_to": ORGANIZATION},
temperature=TEMPERATURE,
max_new_tokens=MAX_TOKENS,
streaming=True
))
}
if PROVIDER not in providers:
raise ValueError(f"Unsupported provider: {PROVIDER}")
return providers[PROVIDER]()
# Initialize chat model
chat_model = _get_chat_model()
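# Supporting another provider only requires a new entry in the `providers` dict
# inside _get_chat_model, e.g. (illustrative sketch, assuming the
# langchain-mistralai package is installed):
#
#   "mistral": lambda: ChatMistralAI(model=MODEL, mistral_api_key=auth_config["api_key"],
#                                    streaming=True, **common_params),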
# ---------------------------------------------------------------------
# LLM Call Functions
# ---------------------------------------------------------------------
async def _call_llm(messages: list) -> str:
"""Provider-agnostic LLM call using LangChain (non-streaming)"""
try:
response = await chat_model.ainvoke(messages)
return response.content.strip()
except Exception as e:
logger.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
raise
async def _call_llm_streaming(messages: list) -> AsyncGenerator[str, None]:
"""Provider-agnostic streaming LLM call using LangChain"""
try:
async for chunk in chat_model.astream(messages):
if hasattr(chunk, 'content') and chunk.content:
yield chunk.content
    except Exception as e:
        logger.exception(f"LLM streaming failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
        # Re-raise so callers can surface the failure as a proper error event
        # instead of embedding it in the streamed answer text
        raise
# ---------------------------------------------------------------------
# Main Generation Functions
# ---------------------------------------------------------------------
async def generate(query: str, context: Union[str, List[Dict[str, Any]]], chatui_format: bool = False) -> Union[str, Dict[str, Any]]:
"""Generate an answer to a query using provided context through RAG"""
if not query.strip():
error_msg = "Query cannot be empty"
return {"error": error_msg} if chatui_format else f"Error: {error_msg}"
try:
formatted_context, processed_results = _process_context(context)
        messages = _build_messages(system_prompt, query, formatted_context)
answer = await _call_llm(messages)
# Clean citations to ensure proper format and remove unwanted sections
answer = clean_citations(answer)
if chatui_format:
result = {"answer": answer}
if processed_results:
cited_numbers = _parse_citations(answer)
cited_sources = _extract_sources(processed_results, cited_numbers)
result["sources"] = _create_sources_list(cited_sources)
return result
else:
return answer
except Exception as e:
logger.exception("Generation failed")
error_msg = str(e)
return {"error": error_msg} if chatui_format else f"Error: {error_msg}"
async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]]], chatui_format: bool = False) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
"""Generate a streaming answer to a query using provided context through RAG"""
if not query.strip():
error_msg = "Query cannot be empty"
if chatui_format:
yield {"event": "error", "data": {"error": error_msg}}
else:
yield f"Error: {error_msg}"
return
try:
formatted_context, processed_results = _process_context(context)
messages = _build_messages(system_prompt, query, formatted_context)
        # Stream the response while accumulating it, so that citations can be
        # parsed afterwards and uncited sources filtered out
        accumulated_response = ""
async for chunk in _call_llm_streaming(messages):
accumulated_response += chunk
if chatui_format:
yield {"event": "data", "data": chunk}
else:
yield chunk
        # Clean citations in the full response; the raw chunks were already
        # streamed, so the cleaned text is only used for source parsing below
cleaned_response = clean_citations(accumulated_response)
# Send sources at the end if available and in ChatUI format
if chatui_format and processed_results:
cited_numbers = _parse_citations(cleaned_response)
cited_sources = _extract_sources(processed_results, cited_numbers)
sources = _create_sources_list(cited_sources)
yield {"event": "sources", "data": {"sources": sources}}
# Send END event for ChatUI format
if chatui_format:
yield {"event": "end", "data": {}}
except Exception as e:
logger.exception("Streaming generation failed")
error_msg = str(e)
if chatui_format:
yield {"event": "error", "data": {"error": error_msg}}
else:
yield f"Error: {error_msg}" |