Commit 964c419
Parent(s): dd426a1

add proper matching

app.py CHANGED
@@ -6,7 +6,7 @@ import nltk
 import string
 from streamlit.components.v1 import html
 from sentence_transformers.cross_encoder import CrossEncoder as CE
-import
+import re
 from typing import List, Tuple
 import torch

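The removed line was a bare `import`, which is a SyntaxError and would have crashed the app at startup; the replacement `import re` is needed by the new matched_context helper further down. The old failure mode, reproduced:

# A bare 'import' never parses, so the old app.py failed before running a single line:
try:
    compile('import', 'app.py', 'exec')
except SyntaxError as err:
    print(err.msg)  # "invalid syntax"
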
@@ -26,7 +26,7 @@ class CrossEncoder:
 def remove_html(x):
     soup = BeautifulSoup(x, 'html.parser')
     text = soup.get_text()
-    return text
+    return text.strip()


 # 4 searches: strict y/n, supported y/n
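remove_html flattens markup with BeautifulSoup; the new .strip() matters later, because find_source now compares cleaned snippets against the matched context for exact equality. A minimal check (bs4 is already a dependency of the app):

from bs4 import BeautifulSoup

def remove_html(x):
    soup = BeautifulSoup(x, 'html.parser')
    text = soup.get_text()
    return text.strip()

# Whitespace left behind by stripped tags no longer breaks equality checks:
assert remove_html('<p> Collagen regulates tumor growth. </p>') == 'Collagen regulates tumor growth.'
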
@@ -58,7 +58,7 @@ def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=Tru
         except:
             pass

-    contexts += [remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']]
+    contexts += [remove_html('\n'.join([cite['snippet'] for cite in doc['citations'] if cite['lang'] == 'en'])) for doc in req.json()['hits']]
     docs += [(doc['doi'], doc['citations'], doc['title'], doc['abstract'] or '')
              for doc in req.json()['hits']]

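The comprehension now keeps only English-language snippets before joining them into one context per document. Note it indexes cite['lang'] directly, so every citation record is assumed to carry a lang field. A sketch with a mocked payload, reusing remove_html from the check above (field names are taken from this diff; the real scite API response shape is an assumption):

hits = [{                      # hypothetical stand-in for req.json()['hits']
    'doi': '10.1000/xyz',
    'title': 'Example paper',
    'abstract': '',
    'citations': [
        {'snippet': 'Collagen regulates tumor growth.', 'lang': 'en'},
        {'snippet': 'Le collagène régule la croissance tumorale.', 'lang': 'fr'},
    ],
}]

contexts = [remove_html('\n'.join([cite['snippet'] for cite in doc['citations'] if cite['lang'] == 'en']))
            for doc in hits]
print(contexts)  # ['Collagen regulates tumor growth.']
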
@@ -85,10 +85,12 @@ def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=Tru
 )


-def find_source(text, docs):
+def find_source(text, docs, matched):
     for doc in docs:
         for snippet in doc[1]:
             if text in remove_html(snippet.get('snippet', '')):
+                if matched and remove_html(snippet.get('snippet', '')).strip() != matched.strip():
+                    continue
                 new_text = text
                 for sent in nltk.sent_tokenize(remove_html(snippet.get('snippet', ''))):
                     if text in sent:
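This is the "proper matching" of the commit message. Previously, the first snippet containing the answer text was credited as the source, even when several documents contain the same phrase; the new matched argument (the context the answer actually came from, recovered below by matched_context) skips any snippet whose cleaned text differs from it. The filter in isolation, with toy strings:

answer = 'reduces inflammation'
matched = 'Drug A reduces inflammation in mice.'
snippets = [
    'Drug B also reduces inflammation, via another pathway.',
    'Drug A reduces inflammation in mice.',
]

for s in snippets:
    if answer in s:
        # old behaviour: the first hit (Drug B) would have been credited
        if matched and s.strip() != matched.strip():
            continue
        print('credited to:', s)  # credited to: Drug A reduces inflammation in mice.
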
@@ -98,10 +100,12 @@ def find_source(text, docs):
                     'text': new_text,
                     'from': snippet['source'],
                     'supporting': snippet['target'],
-                    'source_title': remove_html(doc[2]),
+                    'source_title': remove_html(doc[2] or ''),
                     'source_link': f"https://scite.ai/reports/{doc[0]}"
                 }
         if text in remove_html(doc[3]):
+            if matched and remove_html(doc[3]).strip() != matched.strip():
+                continue
             new_text = text
             for sent in nltk.sent_tokenize(remove_html(doc[3])):
                 if text in sent:
@@ -111,7 +115,7 @@ def find_source(text, docs):
                 'text': new_text,
                 'from': doc[0],
                 'supporting': doc[0],
-                'source_title': "ABSTRACT of " + remove_html(doc[2]),
+                'source_title': "ABSTRACT of " + remove_html(doc[2] or ''),
                 'source_link': f"https://scite.ai/reports/{doc[0]}"
             }
     return None
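The two source_title changes guard against records with a null title: doc[2] can be None, and feeding None to BeautifulSoup inside remove_html raises a TypeError rather than returning an empty string. Coercing with `or ''` keeps the query alive (remove_html as in the sketch above):

title = None                      # e.g. an API record with no title
# remove_html(title) would raise inside BeautifulSoup;
# remove_html(title or '') safely yields ''.
print(remove_html(title or ''))   # prints an empty string
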
@@ -233,6 +237,22 @@ def group_results_by_context(results):
     return list(result_groups.values())


+def matched_context(start_i, end_i, contexts_string, seperator='---'):
+    # find seperators to identify start and end
+    doc_starts = [0]
+    for match in re.finditer(seperator, contexts_string):
+        doc_starts.append(match.end())
+
+    for i in range(len(doc_starts)):
+        if i == len(doc_starts) - 1:
+            if start_i >= doc_starts[i]:
+                return contexts_string[doc_starts[i]:len(contexts_string)].replace(seperator, '')
+
+        if start_i >= doc_starts[i] and end_i <= doc_starts[i+1]:
+            return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(seperator, '')
+    return None
+
+
 def run_query(query):
     # if use_query_exp == 'yes':
     #     query_exp = paraphrase(f"question2question: {query}")
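matched_context inverts the '\n---' join performed below in run_query: given the character span the QA pipeline reports for an answer, it returns the one joined context that span falls inside, with separators removed. Two caveats worth knowing: non-final contexts come back with the '\n' from the join still attached, which is why find_source strips both sides before comparing, and a context that itself contains '---' would be split in two (presumably rare in this data). A quick check of the boundary logic, with the function copied from the hunk above:

import re

def matched_context(start_i, end_i, contexts_string, seperator='---'):
    # find seperators to identify start and end
    doc_starts = [0]
    for match in re.finditer(seperator, contexts_string):
        doc_starts.append(match.end())

    for i in range(len(doc_starts)):
        if i == len(doc_starts) - 1:
            if start_i >= doc_starts[i]:
                return contexts_string[doc_starts[i]:len(contexts_string)].replace(seperator, '')

        if start_i >= doc_starts[i] and end_i <= doc_starts[i+1]:
            return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(seperator, '')
    return None

context = '\n---'.join(['first snippet text', 'second snippet text'])
start = context.index('second')                       # 22: just past the separator
assert matched_context(start, start + 6, context) == 'second snippet text'
assert matched_context(0, 5, context) == 'first snippet text\n'   # trailing '\n' survives
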
@@ -278,19 +298,21 @@ def run_query(query):
     </div>
     </div>
     """, unsafe_allow_html=True)
+
     if use_reranking == 'yes':
         sentence_pairs = [[query, context] for context in contexts]
         scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
         hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
         sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
-        context = '\n'.join(sorted_contexts[:context_limit])
+        context = '\n---'.join(sorted_contexts[:context_limit])
     else:
-        context = '\n'.join(contexts[:context_limit])
+        context = '\n---'.join(contexts[:context_limit])

     results = []
     model_results = qa_model(question=query, context=context, top_k=10)
     for result in model_results:
-        support = find_source(result['answer'], orig_docs)
+        matched = matched_context(result['start'], result['end'], context)
+        support = find_source(result['answer'], orig_docs, matched)
         if not support:
             continue
         results.append({
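Joining with '\n---' rather than '\n' is what makes that recovery possible: the transformers question-answering pipeline reports each answer with start/end character offsets into the context string it received, so the offsets identify a single source context. (Unrelated to this commit, the rerank sort above orders by x[0], the context text, rather than x[1], the score; x[1] looks like the intended key.) Reusing context and matched_context from the previous sketch:

# shape of one model_results entry from the QA pipeline
result = {'score': 0.71, 'start': 22, 'end': 41, 'answer': 'second snippet text'}

matched = matched_context(result['start'], result['end'], context)
assert matched == 'second snippet text'
# then, as in the diff: support = find_source(result['answer'], orig_docs, matched)
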
@@ -316,10 +338,9 @@ def run_query(query):
     )

     for r in sorted_result:
-        answer = r["answer"]
         ctx = remove_html(r["context"])
         for answer in r['texts']:
-            ctx = ctx.replace(answer, f"<mark>{answer}</mark>")
+            ctx = ctx.replace(answer.strip(), f"<mark>{answer.strip()}</mark>")
             # .replace( '<cite', '<a').replace('</cite', '</a').replace('data-doi="', 'href="https://scite.ai/reports/')
         title = r.get("title", '')
         score = round(round(r["score"], 4) * 100, 2)
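The last hunk drops the dead answer = r["answer"] assignment (the loop variable below shadowed it anyway) and strips answers before highlighting: since remove_html now strips the context, an answer carrying edge whitespace would no longer match verbatim and the <mark> wrapping would silently do nothing. A small check:

ctx = 'Collagen regulates tumor growth.'
answer = 'tumor growth. '   # trailing whitespace, e.g. from sentence splitting

assert ctx.replace(answer, f"<mark>{answer}</mark>") == ctx   # unstripped: no-op

stripped = answer.strip()
assert ctx.replace(stripped, f"<mark>{stripped}</mark>") \
       == 'Collagen regulates <mark>tumor growth.</mark>'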