Spaces:
Runtime error
Runtime error
Commit
·
a91b925
1
Parent(s):
e15c8b9
strict and then lenient
Browse files
app.py
CHANGED
|
@@ -151,18 +151,11 @@ st.markdown("""
|
|
| 151 |
|
| 152 |
with st.expander("Settings (strictness, context limit, top hits)"):
|
| 153 |
confidence_threshold = st.slider('Confidence threshold for answering questions? This number represents how confident the model should be in the answers it gives. The number is out of 100%', 0, 100, 1)
|
| 154 |
-
strict_mode = st.radio(
|
| 155 |
-
"Query mode? Strict means all words must match in source snippet. Lenient means only some words must match.",
|
| 156 |
-
('lenient', 'strict'))
|
| 157 |
use_reranking = st.radio(
|
| 158 |
"Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
|
| 159 |
('yes', 'no'))
|
| 160 |
-
top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300,
|
| 161 |
context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25)
|
| 162 |
-
use_query_exp = st.radio(
|
| 163 |
-
"(Experimental) use query expansion? Right now it just recommends queries",
|
| 164 |
-
('yes', 'no'))
|
| 165 |
-
suggested_queries = st.slider('Number of suggested queries to use', 0, 10, 5)
|
| 166 |
|
| 167 |
# def paraphrase(text, max_length=128):
|
| 168 |
# input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
|
|
@@ -180,7 +173,14 @@ def run_query(query):
|
|
| 180 |
# """)
|
| 181 |
limit = top_hits_limit or 100
|
| 182 |
context_limit = context_lim or 10
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
if len(contexts) == 0 or not ''.join(contexts).strip():
|
| 185 |
return st.markdown("""
|
| 186 |
<div class="container-fluid">
|
|
@@ -197,8 +197,7 @@ def run_query(query):
|
|
| 197 |
hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
|
| 198 |
sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
|
| 199 |
context = '\n'.join(sorted_contexts[:context_limit])
|
| 200 |
-
|
| 201 |
-
context = '\n'.join(contexts[:context_limit])
|
| 202 |
results = []
|
| 203 |
model_results = qa_model(question=query, context=context, top_k=10)
|
| 204 |
for result in model_results:
|
|
|
|
| 151 |
|
| 152 |
with st.expander("Settings (strictness, context limit, top hits)"):
|
| 153 |
confidence_threshold = st.slider('Confidence threshold for answering questions? This number represents how confident the model should be in the answers it gives. The number is out of 100%', 0, 100, 1)
|
|
|
|
|
|
|
|
|
|
| 154 |
use_reranking = st.radio(
|
| 155 |
"Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
|
| 156 |
('yes', 'no'))
|
| 157 |
+
top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 100)
|
| 158 |
context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
# def paraphrase(text, max_length=128):
|
| 161 |
# input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
|
|
|
|
| 173 |
# """)
|
| 174 |
limit = top_hits_limit or 100
|
| 175 |
context_limit = context_lim or 10
|
| 176 |
+
contexts_strict, orig_docs_strict = search(query, limit=limit, strict=True)
|
| 177 |
+
contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False)
|
| 178 |
+
|
| 179 |
+
contexts = list(
|
| 180 |
+
set(contexts_strict + contexts_lenient)
|
| 181 |
+
)
|
| 182 |
+
orig_docs = orig_docs_strict + orig_docs_lenient
|
| 183 |
+
|
| 184 |
if len(contexts) == 0 or not ''.join(contexts).strip():
|
| 185 |
return st.markdown("""
|
| 186 |
<div class="container-fluid">
|
|
|
|
| 197 |
hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
|
| 198 |
sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
|
| 199 |
context = '\n'.join(sorted_contexts[:context_limit])
|
| 200 |
+
|
|
|
|
| 201 |
results = []
|
| 202 |
model_results = qa_model(question=query, context=context, top_k=10)
|
| 203 |
for result in model_results:
|