Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -434,14 +434,26 @@ def optimize_vocabulary(texts, vocab_size=10000, min_frequency=2):
|
|
| 434 |
return tokenizer, optimized_texts
|
| 435 |
|
| 436 |
# New preprocessing function
|
| 437 |
-
def optimize_query(query,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
multi_query_retriever = MultiQueryRetriever.from_llm(
|
| 439 |
-
retriever=
|
| 440 |
llm=llm
|
| 441 |
)
|
| 442 |
optimized_queries = multi_query_retriever.generate_queries(query)
|
| 443 |
return optimized_queries
|
| 444 |
-
|
| 445 |
# New postprocessing function
|
| 446 |
def rerank_results(results, query, reranker):
|
| 447 |
reranked_results = reranker.rerank(query, [doc.page_content for doc in results])
|
|
@@ -495,7 +507,7 @@ def compare_embeddings(file, query, embedding_models, custom_embedding_model, sp
|
|
| 495 |
chunks = optimized_chunks
|
| 496 |
|
| 497 |
if use_query_optimization:
|
| 498 |
-
optimized_queries = optimize_query(query, query_optimization_model)
|
| 499 |
query = " ".join(optimized_queries)
|
| 500 |
|
| 501 |
results, search_time, vector_store, results_raw = search_embeddings(
|
|
|
|
| 434 |
return tokenizer, optimized_texts
|
| 435 |
|
| 436 |
# New preprocessing function
|
| 437 |
def optimize_query(query, llm_model, chunks, embedding_model, vector_store_type, search_type, top_k):
    """Rewrite a user query into multiple alternative phrasings via an LLM.

    Builds a temporary vector store over *chunks* so a MultiQueryRetriever can
    be constructed, then asks the LLM to generate reformulations of *query*.

    Args:
        query: The raw user query string to expand.
        llm_model: HuggingFace model id used for the text2text-generation LLM.
        chunks: Document chunks used to build the temporary vector store.
        embedding_model: Embedding model for the temporary vector store.
        vector_store_type: Backend identifier passed to get_vector_store.
        search_type: Retrieval strategy identifier passed to get_retriever.
        top_k: Number of results the temporary retriever returns.

    Returns:
        A list of alternative query strings produced by the LLM.
    """
    # Greedy decoding (do_sample=False) gives deterministic rewrites.
    # NOTE: the previous kwargs combined do_sample=True with temperature=0,
    # which transformers rejects (temperature must be strictly positive).
    llm = HuggingFacePipeline.from_model_id(
        model_id=llm_model,
        task="text2text-generation",
        model_kwargs={"do_sample": False, "max_new_tokens": 64},
    )

    # Temporary vector store + retriever exist only to satisfy the
    # MultiQueryRetriever constructor; only query generation is used below.
    temp_vector_store = get_vector_store(vector_store_type, chunks, embedding_model)
    temp_retriever = get_retriever(temp_vector_store, search_type, {"k": top_k})

    multi_query_retriever = MultiQueryRetriever.from_llm(
        retriever=temp_retriever,
        llm=llm
    )

    # NOTE(review): generate_queries is not part of MultiQueryRetriever's
    # stable public API (the public entry point is invoke/get_relevant_documents,
    # which returns Documents, not query strings). Callers here join the
    # returned strings, so the call is kept — verify against the installed
    # langchain version.
    optimized_queries = multi_query_retriever.generate_queries(query)
    return optimized_queries
|
| 456 |
+
|
| 457 |
# New postprocessing function
|
| 458 |
def rerank_results(results, query, reranker):
|
| 459 |
reranked_results = reranker.rerank(query, [doc.page_content for doc in results])
|
|
|
|
| 507 |
chunks = optimized_chunks
|
| 508 |
|
| 509 |
if use_query_optimization:
|
| 510 |
+
optimized_queries = optimize_query(query, query_optimization_model, chunks, embedding_model, vector_store_type, search_type, top_k)
|
| 511 |
query = " ".join(optimized_queries)
|
| 512 |
|
| 513 |
results, search_time, vector_store, results_raw = search_embeddings(
|