Spaces:
Runtime error
Runtime error
Commit
Β·
69d7ac6
1
Parent(s):
4c36cd4
add ability to specify strict or lenient
Browse files
app.py
CHANGED
|
@@ -41,6 +41,10 @@ def search(term, limit=10, clean=True, strict=True):
|
|
| 41 |
'Authorization': f'Bearer {SCITE_API_KEY}'
|
| 42 |
}
|
| 43 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
return (
|
| 45 |
[remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']],
|
| 46 |
[(doc['doi'], doc['citations'], doc['title'])
|
|
@@ -80,6 +84,7 @@ def init_models():
|
|
| 80 |
|
| 81 |
qa_model, reranker, stop, device = init_models()
|
| 82 |
|
|
|
|
| 83 |
def clean_query(query, strict=True, clean=True):
|
| 84 |
operator = ' '
|
| 85 |
if strict:
|
|
@@ -136,6 +141,10 @@ st.markdown("""
|
|
| 136 |
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
|
| 137 |
""", unsafe_allow_html=True)
|
| 138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
def run_query(query):
|
| 140 |
if device == 'cpu':
|
| 141 |
limit = 50
|
|
@@ -143,7 +152,7 @@ def run_query(query):
|
|
| 143 |
else:
|
| 144 |
limit = 100
|
| 145 |
context_limit = 25
|
| 146 |
-
contexts, orig_docs = search(query, limit=limit)
|
| 147 |
if len(contexts) == 0 or not ''.join(contexts).strip():
|
| 148 |
return st.markdown("""
|
| 149 |
<div class="container-fluid">
|
|
|
|
| 41 |
'Authorization': f'Bearer {SCITE_API_KEY}'
|
| 42 |
}
|
| 43 |
)
|
| 44 |
+
try:
|
| 45 |
+
req.json()
|
| 46 |
+
except:
|
| 47 |
+
return [], []
|
| 48 |
return (
|
| 49 |
[remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']],
|
| 50 |
[(doc['doi'], doc['citations'], doc['title'])
|
|
|
|
| 84 |
|
| 85 |
qa_model, reranker, stop, device = init_models()
|
| 86 |
|
| 87 |
+
|
| 88 |
def clean_query(query, strict=True, clean=True):
|
| 89 |
operator = ' '
|
| 90 |
if strict:
|
|
|
|
| 141 |
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
|
| 142 |
""", unsafe_allow_html=True)
|
| 143 |
|
| 144 |
+
strict_mode = st.radio(
|
| 145 |
+
"Query mode? Strict means all words must match in source snippet. Lenient means only some words must match.",
|
| 146 |
+
('strict', 'lenient'))
|
| 147 |
+
|
| 148 |
def run_query(query):
|
| 149 |
if device == 'cpu':
|
| 150 |
limit = 50
|
|
|
|
| 152 |
else:
|
| 153 |
limit = 100
|
| 154 |
context_limit = 25
|
| 155 |
+
contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')
|
| 156 |
if len(contexts) == 0 or not ''.join(contexts).strip():
|
| 157 |
return st.markdown("""
|
| 158 |
<div class="container-fluid">
|