Spaces:
Runtime error
Runtime error
Commit
Β·
3f1f616
1
Parent(s):
bdb2b00
update to use api
Browse files
app.py
CHANGED
|
@@ -12,15 +12,15 @@ import torch
|
|
| 12 |
|
| 13 |
SCITE_API_KEY = st.secrets["SCITE_API_KEY"]
|
| 14 |
|
| 15 |
-
class CrossEncoder:
|
| 16 |
-
|
| 17 |
-
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
|
| 25 |
|
| 26 |
def remove_html(x):
|
|
@@ -134,23 +134,23 @@ def find_source(text, docs, matched):
|
|
| 134 |
return None
|
| 135 |
|
| 136 |
|
| 137 |
-
@st.experimental_singleton
|
| 138 |
-
def init_models():
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
|
| 153 |
-
qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer
|
| 154 |
|
| 155 |
|
| 156 |
def clean_query(query, strict=True, clean=True):
|
|
@@ -206,32 +206,32 @@ Answers are linked to source documents containing citations where users can expl
|
|
| 206 |
For example try: Do tanning beds cause cancer?
|
| 207 |
""")
|
| 208 |
|
| 209 |
-
st.markdown("""
|
| 210 |
-
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
|
| 211 |
-
""", unsafe_allow_html=True)
|
| 212 |
-
|
| 213 |
-
with st.expander("Settings (strictness, context limit, top hits)"):
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
|
| 236 |
# def paraphrase(text, max_length=128):
|
| 237 |
# input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
|
|
@@ -272,38 +272,120 @@ def matched_context(start_i, end_i, contexts_string, seperator='---'):
|
|
| 272 |
return None
|
| 273 |
|
| 274 |
|
| 275 |
-
def
|
| 276 |
-
# if use_query_exp == 'yes':
|
| 277 |
-
# query_exp = paraphrase(f"question2question: {query}")
|
| 278 |
-
# st.markdown(f"""
|
| 279 |
-
# If you are not getting good results try one of:
|
| 280 |
-
# * {query_exp}
|
| 281 |
-
# """)
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
|
| 306 |
-
if len(
|
| 307 |
return st.markdown("""
|
| 308 |
<div class="container-fluid">
|
| 309 |
<div class="row align-items-start">
|
|
@@ -314,58 +396,7 @@ def run_query(query, progress_bar):
|
|
| 314 |
</div>
|
| 315 |
""", unsafe_allow_html=True)
|
| 316 |
|
| 317 |
-
|
| 318 |
-
sentence_pairs = [[query, context] for context in contexts]
|
| 319 |
-
scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
|
| 320 |
-
hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
|
| 321 |
-
sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
|
| 322 |
-
contexts = sorted_contexts[:context_limit]
|
| 323 |
-
else:
|
| 324 |
-
contexts = contexts[:context_limit]
|
| 325 |
-
|
| 326 |
-
progress_bar.progress(50)
|
| 327 |
-
if concat_passages == 'yes':
|
| 328 |
-
context = '\n---'.join(contexts)
|
| 329 |
-
model_results = qa_model(question=query, context=context, top_k=10, doc_stride=512 // 2, max_answer_len=128, max_seq_len=512, handle_impossible_answer=present_impossible=='yes')
|
| 330 |
-
else:
|
| 331 |
-
context = ['\n---\n'+ctx for ctx in contexts]
|
| 332 |
-
model_results = qa_model(question=[query]*len(contexts), context=context, handle_impossible_answer=present_impossible=='yes')
|
| 333 |
-
|
| 334 |
-
results = []
|
| 335 |
-
|
| 336 |
-
progress_bar.progress(75)
|
| 337 |
-
for i, result in enumerate(model_results):
|
| 338 |
-
if concat_passages == 'yes':
|
| 339 |
-
matched = matched_context(result['start'], result['end'], context)
|
| 340 |
-
else:
|
| 341 |
-
matched = matched_context(result['start'], result['end'], context[i])
|
| 342 |
-
support = find_source(result['answer'], orig_docs, matched)
|
| 343 |
-
if not support:
|
| 344 |
-
continue
|
| 345 |
-
results.append({
|
| 346 |
-
"answer": support['text'],
|
| 347 |
-
"title": support['source_title'],
|
| 348 |
-
"link": support['source_link'],
|
| 349 |
-
"context": support['citation_statement'],
|
| 350 |
-
"score": result['score'],
|
| 351 |
-
"doi": support["supporting"]
|
| 352 |
-
})
|
| 353 |
-
|
| 354 |
-
grouped_results = group_results_by_context(results)
|
| 355 |
-
sorted_result = sorted(grouped_results, key=lambda x: x['score'], reverse=True)
|
| 356 |
-
|
| 357 |
-
if confidence_threshold == 0:
|
| 358 |
-
threshold = 0
|
| 359 |
-
else:
|
| 360 |
-
threshold = (confidence_threshold or 10) / 100
|
| 361 |
-
|
| 362 |
-
sorted_result = list(filter(
|
| 363 |
-
lambda x: x['score'] > threshold,
|
| 364 |
-
sorted_result
|
| 365 |
-
))
|
| 366 |
-
|
| 367 |
-
progress_bar.progress(100)
|
| 368 |
-
for r in sorted_result:
|
| 369 |
ctx = remove_html(r["context"])
|
| 370 |
for answer in r['texts']:
|
| 371 |
ctx = ctx.replace(answer.strip(), f"<mark>{answer.strip()}</mark>")
|
|
@@ -377,5 +408,4 @@ def run_query(query, progress_bar):
|
|
| 377 |
query = st.text_input("Ask scientific literature a question", "")
|
| 378 |
if query != "":
|
| 379 |
with st.spinner('Loading...'):
|
| 380 |
-
|
| 381 |
-
run_query(query, progress_bar)
|
|
|
|
| 12 |
|
| 13 |
SCITE_API_KEY = st.secrets["SCITE_API_KEY"]
|
| 14 |
|
| 15 |
+
# class CrossEncoder:
|
| 16 |
+
# def __init__(self, model_path: str, **kwargs):
|
| 17 |
+
# self.model = CE(model_path, **kwargs)
|
| 18 |
|
| 19 |
+
# def predict(self, sentences: List[Tuple[str,str]], batch_size: int = 32, show_progress_bar: bool = True) -> List[float]:
|
| 20 |
+
# return self.model.predict(
|
| 21 |
+
# sentences=sentences,
|
| 22 |
+
# batch_size=batch_size,
|
| 23 |
+
# show_progress_bar=show_progress_bar)
|
| 24 |
|
| 25 |
|
| 26 |
def remove_html(x):
|
|
|
|
| 134 |
return None
|
| 135 |
|
| 136 |
|
| 137 |
+
# @st.experimental_singleton
|
| 138 |
+
# def init_models():
|
| 139 |
+
# nltk.download('stopwords')
|
| 140 |
+
# nltk.download('punkt')
|
| 141 |
+
# from nltk.corpus import stopwords
|
| 142 |
+
# stop = set(stopwords.words('english') + list(string.punctuation))
|
| 143 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 144 |
+
# question_answerer = pipeline(
|
| 145 |
+
# "question-answering", model='nlpconnect/roberta-base-squad2-nq',
|
| 146 |
+
# device=0 if torch.cuda.is_available() else -1, handle_impossible_answer=False,
|
| 147 |
+
# )
|
| 148 |
+
# reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
|
| 149 |
+
# # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
|
| 150 |
+
# # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
|
| 151 |
+
# return question_answerer, reranker, stop, device
|
| 152 |
|
| 153 |
+
# qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer
|
| 154 |
|
| 155 |
|
| 156 |
def clean_query(query, strict=True, clean=True):
|
|
|
|
| 206 |
For example try: Do tanning beds cause cancer?
|
| 207 |
""")
|
| 208 |
|
| 209 |
+
# st.markdown("""
|
| 210 |
+
# <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
|
| 211 |
+
# """, unsafe_allow_html=True)
|
| 212 |
+
|
| 213 |
+
# with st.expander("Settings (strictness, context limit, top hits)"):
|
| 214 |
+
# concat_passages = st.radio(
|
| 215 |
+
# "Concatenate passages as one long context?",
|
| 216 |
+
# ('yes', 'no'))
|
| 217 |
+
# present_impossible = st.radio(
|
| 218 |
+
# "Present impossible answers? (if the model thinks its impossible to answer should it still try?)",
|
| 219 |
+
# ('yes', 'no'))
|
| 220 |
+
# support_all = st.radio(
|
| 221 |
+
# "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
|
| 222 |
+
# ('no', 'yes'))
|
| 223 |
+
# support_abstracts = st.radio(
|
| 224 |
+
# "Use abstracts as a source document?",
|
| 225 |
+
# ('yes', 'no', 'abstract only'))
|
| 226 |
+
# strict_lenient_mix = st.radio(
|
| 227 |
+
# "Type of strict+lenient combination: Fallback or Mix? If fallback, strict is run first then if the results are less than context_lim we also search lenient. Mix will search them both and let reranking sort em out",
|
| 228 |
+
# ('mix', 'fallback'))
|
| 229 |
+
# confidence_threshold = st.slider('Confidence threshold for answering questions? This number represents how confident the model should be in the answers it gives. The number is out of 100%', 0, 100, 1)
|
| 230 |
+
# use_reranking = st.radio(
|
| 231 |
+
# "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
|
| 232 |
+
# ('yes', 'no'))
|
| 233 |
+
# top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 100)
|
| 234 |
+
# context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25)
|
| 235 |
|
| 236 |
# def paraphrase(text, max_length=128):
|
| 237 |
# input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
|
|
|
|
| 272 |
return None
|
| 273 |
|
| 274 |
|
| 275 |
+
# def run_query_full(query, progress_bar):
|
| 276 |
+
# # if use_query_exp == 'yes':
|
| 277 |
+
# # query_exp = paraphrase(f"question2question: {query}")
|
| 278 |
+
# # st.markdown(f"""
|
| 279 |
+
# # If you are not getting good results try one of:
|
| 280 |
+
# # * {query_exp}
|
| 281 |
+
# # """)
|
| 282 |
+
|
| 283 |
+
# # could also try fallback if there are no good answers by score...
|
| 284 |
+
# limit = top_hits_limit or 100
|
| 285 |
+
# context_limit = context_lim or 10
|
| 286 |
+
# contexts_strict, orig_docs_strict = search(query, limit=limit, strict=True, all_mode=support_all == 'yes', abstracts= support_abstracts == 'yes', abstract_only=support_abstracts == 'abstract only')
|
| 287 |
+
# if strict_lenient_mix == 'fallback' and len(contexts_strict) < context_limit:
|
| 288 |
+
# contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False, all_mode=support_all == 'yes', abstracts= support_abstracts == 'yes', abstract_only= support_abstracts == 'abstract only')
|
| 289 |
+
# contexts = list(
|
| 290 |
+
# set(contexts_strict + contexts_lenient)
|
| 291 |
+
# )
|
| 292 |
+
# orig_docs = orig_docs_strict + orig_docs_lenient
|
| 293 |
+
# elif strict_lenient_mix == 'mix':
|
| 294 |
+
# contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False)
|
| 295 |
+
# contexts = list(
|
| 296 |
+
# set(contexts_strict + contexts_lenient)
|
| 297 |
+
# )
|
| 298 |
+
# orig_docs = orig_docs_strict + orig_docs_lenient
|
| 299 |
+
# else:
|
| 300 |
+
# contexts = list(
|
| 301 |
+
# set(contexts_strict)
|
| 302 |
+
# )
|
| 303 |
+
# orig_docs = orig_docs_strict
|
| 304 |
+
# progress_bar.progress(25)
|
| 305 |
+
|
| 306 |
+
# if len(contexts) == 0 or not ''.join(contexts).strip():
|
| 307 |
+
# return st.markdown("""
|
| 308 |
+
# <div class="container-fluid">
|
| 309 |
+
# <div class="row align-items-start">
|
| 310 |
+
# <div class="col-md-12 col-sm-12">
|
| 311 |
+
# Sorry... no results for that question! Try another...
|
| 312 |
+
# </div>
|
| 313 |
+
# </div>
|
| 314 |
+
# </div>
|
| 315 |
+
# """, unsafe_allow_html=True)
|
| 316 |
+
|
| 317 |
+
# if use_reranking == 'yes':
|
| 318 |
+
# sentence_pairs = [[query, context] for context in contexts]
|
| 319 |
+
# scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
|
| 320 |
+
# hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
|
| 321 |
+
# sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
|
| 322 |
+
# contexts = sorted_contexts[:context_limit]
|
| 323 |
+
# else:
|
| 324 |
+
# contexts = contexts[:context_limit]
|
| 325 |
+
|
| 326 |
+
# progress_bar.progress(50)
|
| 327 |
+
# if concat_passages == 'yes':
|
| 328 |
+
# context = '\n---'.join(contexts)
|
| 329 |
+
# model_results = qa_model(question=query, context=context, top_k=10, doc_stride=512 // 2, max_answer_len=128, max_seq_len=512, handle_impossible_answer=present_impossible=='yes')
|
| 330 |
+
# else:
|
| 331 |
+
# context = ['\n---\n'+ctx for ctx in contexts]
|
| 332 |
+
# model_results = qa_model(question=[query]*len(contexts), context=context, handle_impossible_answer=present_impossible=='yes')
|
| 333 |
+
|
| 334 |
+
# results = []
|
| 335 |
+
|
| 336 |
+
# progress_bar.progress(75)
|
| 337 |
+
# for i, result in enumerate(model_results):
|
| 338 |
+
# if concat_passages == 'yes':
|
| 339 |
+
# matched = matched_context(result['start'], result['end'], context)
|
| 340 |
+
# else:
|
| 341 |
+
# matched = matched_context(result['start'], result['end'], context[i])
|
| 342 |
+
# support = find_source(result['answer'], orig_docs, matched)
|
| 343 |
+
# if not support:
|
| 344 |
+
# continue
|
| 345 |
+
# results.append({
|
| 346 |
+
# "answer": support['text'],
|
| 347 |
+
# "title": support['source_title'],
|
| 348 |
+
# "link": support['source_link'],
|
| 349 |
+
# "context": support['citation_statement'],
|
| 350 |
+
# "score": result['score'],
|
| 351 |
+
# "doi": support["supporting"]
|
| 352 |
+
# })
|
| 353 |
+
|
| 354 |
+
# grouped_results = group_results_by_context(results)
|
| 355 |
+
# sorted_result = sorted(grouped_results, key=lambda x: x['score'], reverse=True)
|
| 356 |
+
|
| 357 |
+
# if confidence_threshold == 0:
|
| 358 |
+
# threshold = 0
|
| 359 |
+
# else:
|
| 360 |
+
# threshold = (confidence_threshold or 10) / 100
|
| 361 |
+
|
| 362 |
+
# sorted_result = list(filter(
|
| 363 |
+
# lambda x: x['score'] > threshold,
|
| 364 |
+
# sorted_result
|
| 365 |
+
# ))
|
| 366 |
+
|
| 367 |
+
# progress_bar.progress(100)
|
| 368 |
+
# for r in sorted_result:
|
| 369 |
+
# ctx = remove_html(r["context"])
|
| 370 |
+
# for answer in r['texts']:
|
| 371 |
+
# ctx = ctx.replace(answer.strip(), f"<mark>{answer.strip()}</mark>")
|
| 372 |
+
# # .replace( '<cite', '<a').replace('</cite', '</a').replace('data-doi="', 'href="https://scite.ai/reports/')
|
| 373 |
+
# title = r.get("title", '')
|
| 374 |
+
# score = round(round(r["score"], 4) * 100, 2)
|
| 375 |
+
# card(title, ctx, score, r['link'], r['doi'])
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def run_query(query):
|
| 379 |
+
api_location = 'http://74.82.31.93'
|
| 380 |
+
resp_raw = requests.get(
|
| 381 |
+
f'{api_location}/question-answer?query={query}'
|
| 382 |
+
)
|
| 383 |
+
try:
|
| 384 |
+
resp = resp_raw.json()
|
| 385 |
+
except:
|
| 386 |
+
resp = {'results': []}
|
| 387 |
|
| 388 |
+
if len(resp.get('results', [])) == 0:
|
| 389 |
return st.markdown("""
|
| 390 |
<div class="container-fluid">
|
| 391 |
<div class="row align-items-start">
|
|
|
|
| 396 |
</div>
|
| 397 |
""", unsafe_allow_html=True)
|
| 398 |
|
| 399 |
+
for r in resp['results']:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
ctx = remove_html(r["context"])
|
| 401 |
for answer in r['texts']:
|
| 402 |
ctx = ctx.replace(answer.strip(), f"<mark>{answer.strip()}</mark>")
|
|
|
|
| 408 |
query = st.text_input("Ask scientific literature a question", "")
|
| 409 |
if query != "":
|
| 410 |
with st.spinner('Loading...'):
|
| 411 |
+
run_query(query)
|
|
|