"""Centralized API clients and generation helpers.

Wraps Weaviate, Gemini, Groq, SiliconFlow, and DeepInfra behind small helper
functions so the rest of the codebase never touches raw credentials or HTTP
details directly.
"""

import logging
import os

import google.genai as genai
import requests
import weaviate
from dotenv import load_dotenv
from groq import Groq
from openai import OpenAI
from weaviate.classes.init import Auth

from configs import load_yaml_config
from prompt_template import Prompt_template_LLM_Generation
|
|
try:
    from logger.custom_logger import CustomLoggerTracker

    custom_log = CustomLoggerTracker()
    logger = custom_log.get_logger("clients")
except ImportError:
    # Fall back to the stdlib logger when the custom tracker is unavailable.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("clients")
|
|
load_dotenv()

config = load_yaml_config("config.yaml")
|
|
# Required provider settings; a missing variable raises KeyError at import time.
GROQ_URL = os.environ["GROQ_URL"]
GROQ_API_TOKEN = os.environ["GROQ_API_TOKEN"]

DEEPINFRA_API_KEY = os.environ["DEEPINFRA_API_KEY"]
DEEPINFRA_URL = os.environ["DEEPINFRA_URL"]
DEEPINFRA_EMBEDDING_URL = os.environ["DEEPINFRA_EMBEDDING_URL"]
DEEPINFRA_RERANK_URL = os.environ["DEEPINFRA_RERANK_URL"]
|
|
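# Illustrative .env sketch for the variables above (placeholder values, not
# verified endpoints or keys -- adjust to your deployment):
#     GROQ_URL=https://api.groq.com/openai/v1
#     GROQ_API_TOKEN=gsk_...
#     DEEPINFRA_API_KEY=...
#     DEEPINFRA_URL=https://api.deepinfra.com/v1/openai
#     DEEPINFRA_EMBEDDING_URL=...
#     DEEPINFRA_RERANK_URL=...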
|
|
def init_weaviate_client():
    """Connect to Weaviate Cloud, or return None when credentials are missing."""
    url = os.getenv("WEAVIATE_URL")
    api_key = os.getenv("WEAVIATE_API_KEY")
    if not url or not api_key:
        logger.warning("Weaviate credentials missing (WEAVIATE_URL/WEAVIATE_API_KEY).")
        return None

    logger.info("Attempting to connect to Weaviate cloud...")
    client = weaviate.connect_to_weaviate_cloud(
        cluster_url=url,
        auth_credentials=Auth.api_key(api_key),
        skip_init_checks=True)
    if client is None:
        logger.error("Failed to initialize Weaviate client.")
        return None
    logger.info("Successfully connected to Weaviate cloud.")
    return client
|
|
def get_weaviate_client():
    """Lazily create and cache a single Weaviate client on first use."""
    if not hasattr(get_weaviate_client, '_client'):
        get_weaviate_client._client = init_weaviate_client()
    return get_weaviate_client._client
|
|
def close_weaviate_client():
    """Close the Weaviate client connection if it exists."""
    if hasattr(get_weaviate_client, '_client') and get_weaviate_client._client:
        get_weaviate_client._client.close()
        delattr(get_weaviate_client, '_client')
|
|
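# Minimal lifecycle sketch (assumes WEAVIATE_URL and WEAVIATE_API_KEY are set,
# e.g. via the .env loaded above):
#     client = get_weaviate_client()      # first call connects and caches
#     if client is not None:
#         ...                             # run queries against collections
#     close_weaviate_client()             # drop the cached connection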
|
|
def gemini_client():
    """Return a Gemini API client."""
    return genai.Client(api_key=os.environ["GEMINI_API_KEY"])


def groq_client():
    """Return a Groq API client."""
    return Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
|
def groq_qwen_generate_content(prompt: str) -> str:
    """Streaming chat completion with gpt-oss on Groq via the OpenAI client."""
    if not (GROQ_URL and GROQ_API_TOKEN):
        logger.error("GROQ_URL or GROQ_API_TOKEN not configured.")
        return ""

    client = OpenAI(base_url=GROQ_URL, api_key=GROQ_API_TOKEN)
    logger.info("Calling openai/gpt-oss-120b for generation from Groq...")

    output = ""
    response = client.chat.completions.create(
        model=config["apis_models"]["groq"]["openai"]["gpt_oss"],
        messages=[{"role": "user", "content": prompt}],
        stream=True)
    for chunk in response:
        if not getattr(chunk, "choices", None):
            continue
        delta = chunk.choices[0].delta
        if getattr(delta, "content", None):
            output += delta.content
        # Some providers stream reasoning tokens in a separate field.
        if getattr(delta, "reasoning_content", None):
            output += delta.reasoning_content
    logger.info("Successfully generated content from Groq.")
    return output.strip()
|
|
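# Usage sketch: each generation helper returns the accumulated stream as one
# string, or "" when credentials are missing:
#     answer = groq_qwen_generate_content("Explain RAG in one sentence.")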
|
|
def siliconflow_qwen_generate_content(prompt: str) -> str:
    """Streaming chat completion with Qwen on SiliconFlow via the OpenAI client."""
    url = os.environ.get("SILICONFLOW_URL")
    api_key = os.environ.get("SILICONFLOW_API_KEY")
    if not (url and api_key):
        logger.error("SILICONFLOW_URL or SILICONFLOW_API_KEY not configured.")
        return ""

    client = OpenAI(base_url=url, api_key=api_key)
    model = config["apis_models"]["silicon_flow"]["qwen"]["chat3_30b"]
    logger.info(f"Calling {model} for generation from SiliconFlow...")

    output = ""
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True)
    for chunk in response:
        if not getattr(chunk, "choices", None):
            continue
        delta = chunk.choices[0].delta
        if getattr(delta, "content", None):
            output += delta.content
        # Qwen models on SiliconFlow may stream reasoning tokens separately.
        if getattr(delta, "reasoning_content", None):
            output += delta.reasoning_content
    logger.info("Successfully generated content with Qwen.")
    return output.strip()
|
|
def deepinfra_qwen_generate_content(prompt: str) -> str:
    """Streaming chat completion with gpt-oss on DeepInfra via the OpenAI client."""
    if not (DEEPINFRA_URL and DEEPINFRA_API_KEY):
        logger.error("DEEPINFRA_URL or DEEPINFRA_API_KEY not configured.")
        return ""

    client = OpenAI(base_url=DEEPINFRA_URL, api_key=DEEPINFRA_API_KEY)
    logger.info("Calling openai/gpt-oss-120b for generation from DeepInfra...")

    output = ""
    response = client.chat.completions.create(
        # Reuses the gpt-oss model id defined under the groq config section.
        model=config["apis_models"]["groq"]["openai"]["gpt_oss"],
        messages=[{"role": "user", "content": prompt}],
        temperature=1,
        max_completion_tokens=8192,
        top_p=1,
        reasoning_effort="low",
        stream=True,
        tools=[{"type": "browser_search"}])
    for chunk in response:
        if not getattr(chunk, "choices", None):
            continue
        delta = chunk.choices[0].delta
        if getattr(delta, "content", None):
            output += delta.content
        # Reasoning tokens, when present, are streamed in a separate field.
        if getattr(delta, "reasoning_content", None):
            output += delta.reasoning_content
    logger.info("Successfully generated content from DeepInfra.")
    return output.strip()
|
|
def deepinfra_embedding(texts: list[str], batch_size: int = 50) -> list[list[float]]:
    """Embed texts in batches via the DeepInfra embedding endpoint.

    Failed batches yield empty lists so the output always aligns with the input.
    """
    all_embeddings = []
    headers = {
        "Authorization": f"Bearer {DEEPINFRA_API_KEY}",
        "Content-Type": "application/json"}
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        payload = {
            "model": config["apis_models"]["deepinfra"]["qwen"]["embed"],
            "input": batch}
        try:
            response = requests.post(
                DEEPINFRA_EMBEDDING_URL, json=payload, headers=headers, timeout=60)
            if response.status_code != 200:
                logger.error(f"DeepInfra API error {response.status_code}: {response.text}")
                all_embeddings.extend([[] for _ in batch])
                continue
            data = response.json()
            if "detail" in data and "error" in data["detail"]:
                logger.error(f"DeepInfra API error: {data['detail']['error']}")
                all_embeddings.extend([[] for _ in batch])
                continue
            if "data" not in data:
                logger.error(f"Invalid response format: {data}")
                all_embeddings.extend([[] for _ in batch])
                continue
            batch_embs = [item["embedding"] for item in data["data"]]
            all_embeddings.extend(batch_embs)
        except requests.RequestException as e:
            logger.error(f"Request failed: {e}")
            all_embeddings.extend([[] for _ in batch])
    return all_embeddings
|
|
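# Usage sketch: output order matches input order, with [] marking failed
# batches, so callers should check each vector before use:
#     vectors = deepinfra_embedding(["first doc", "second doc"], batch_size=2)
#     valid = [v for v in vectors if v]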
|
|
def deepinfra_rerank(batch: list[str], items_to_rerank: list[str]) -> str | list[str]:
    """Rerank items via DeepInfra; return the top item, or `batch` on failure.

    Assumes the API returns one scored result per entry in `items_to_rerank`,
    in the same order.
    """
    payload = {
        "model": config["apis_models"]["deepinfra"]["qwen"]["rerank"],
        "input": batch}
    headers = {
        "Authorization": f"Bearer {DEEPINFRA_API_KEY}",
        "Content-Type": "application/json"}
    r = requests.post(
        DEEPINFRA_RERANK_URL,
        json=payload,
        headers=headers,
        timeout=60)
    if not r.ok:
        return batch
    rerank_data = r.json()
    ranked_docs = sorted(
        zip(rerank_data.get("results", []), items_to_rerank),
        key=lambda x: x[0].get("relevance_score", 0),
        reverse=True)
    # Keep only the highest-scoring item; fall back to the raw batch if empty.
    return ranked_docs[0][1] if ranked_docs else batch
|
|
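# Usage sketch: on success only the single best-scoring item comes back; on
# any HTTP failure the original batch is returned unchanged:
#     top = deepinfra_rerank(["what is autism"], ["what is autism"])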
|
|
def deepinfra_client():
    """Return an OpenAI-compatible client pointed at DeepInfra."""
    return OpenAI(api_key=DEEPINFRA_API_KEY, base_url=DEEPINFRA_URL)
|
|
def qwen_generate(prompt: str) -> str:
    """Route streaming generation to SiliconFlow or Groq based on config."""
    if config["apis_models"]["num"] == 1:
        return siliconflow_qwen_generate_content(prompt)
    return groq_qwen_generate_content(prompt)
|
|
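# Illustrative config.yaml fragment backing the switch above (keys inferred
# from how this module reads the config):
#     apis_models:
#       num: 1   # 1 routes to SiliconFlow; any other value routes to Groq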
|
|
if __name__ == "__main__":
    # Smoke-test each provider helper in turn.
    gen_prompt = Prompt_template_LLM_Generation.format(new_query="what is autism")
    logger.info(f"groq generate response: {groq_qwen_generate_content(gen_prompt)}")

    print("=" * 50)
    response = siliconflow_qwen_generate_content("what is autism")
    logger.info(f"siliconflow qwen response: {response}")

    print("=" * 50)
    response = deepinfra_embedding(["what is autism"], 1)
    if response and response[0]:
        logger.info(f"deepinfra embedding response: {response}")
    else:
        raise ValueError("Empty embeddings returned")

    print("=" * 50)
    response = deepinfra_rerank(["what is autism"], ["what is autism"])
    logger.info(f"deepinfra rerank response: {response}")