Spaces:
Runtime error
Commit f5555cd · 1 Parent(s): 2d39184
experiment with summarization
app.py
CHANGED
@@ -149,9 +149,10 @@ def init_models():
     reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2', device=device)
     # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
     # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
-    return question_answerer, reranker, stop, device
+    summarizer = pipeline("summarization")
+    return question_answerer, reranker, stop, device, summarizer
 
-qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer
+qa_model, reranker, stop, device, summarizer = init_models() # queryexp_model, queryexp_tokenizer
 
 
 def clean_query(query, strict=True, clean=True):
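
Note on the summarizer init: called without a model argument, transformers' pipeline("summarization") downloads a default checkpoint at startup (a DistilBART model fine-tuned on CNN/DailyMail at the time of writing). A minimal sketch of the call pattern; the explicit model pin and the length/truncation settings are assumptions for illustration, not part of this commit:

from transformers import pipeline

# Pinning a checkpoint keeps behavior stable if the library's default changes.
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")

text = "Deletions involving chromosome region 4p16.3 cause Wolf-Hirschhorn syndrome ..."
result = summarizer(text, max_length=80, min_length=10, truncation=True)
print(result[0]["summary_text"])  # the pipeline returns a list of {"summary_text": ...} dicts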
@@ -212,6 +213,9 @@ st.markdown("""
 """, unsafe_allow_html=True)
 
 with st.expander("Settings (strictness, context limit, top hits)"):
+    use_mds = st.radio(
+        "Use multi-document summarization to summarize answer?",
+        ('yes', 'no'))
     support_all = st.radio(
         "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
         ('yes', 'no'))
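
Note on the new setting: st.radio returns the selected option and preselects the first one, so with options ('yes', 'no') summarization is enabled by default. If opt-in behavior were wanted instead, Streamlit's index argument could preselect 'no' (a hypothetical variant, not in this commit):

use_mds = st.radio(
    "Use multi-document summarization to summarize answer?",
    ('yes', 'no'),
    index=1)  # preselect 'no', making summarization opt-in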
@@ -267,6 +271,21 @@ def matched_context(start_i, end_i, contexts_string, seperator='---'):
     return None
 
 
+def gen_summary(query, sorted_result):
+    doc_sep = '\n'
+    summary = summarizer(f'{query} '.join([f'{doc_sep}'.join(r['texts']) + r['context'] for r in sorted_result]))[0]['summary_text']
+    st.markdown(f"""
+    <div class="container-fluid">
+        <div class="row align-items-start">
+            <div class="col-md-12 col-sm-12">
+                <strong>Answer:</strong> {summary}
+            </div>
+        </div>
+    </div>
+    """, unsafe_allow_html=True)
+    st.markdown("<br /><br /><h5>Sources:</h5>", unsafe_allow_html=True)
+
+
 def run_query(query):
     # if use_query_exp == 'yes':
     #     query_exp = paraphrase(f"question2question: {query}")
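
Note on how gen_summary assembles the summarizer input: str.join places the query between consecutive documents rather than before each one, and no separator is inserted between a document's context and the query that follows it. A sketch with hypothetical result dicts showing the exact string produced:

query = 'what causes WHS?'
doc_sep = '\n'
sorted_result = [
    {'texts': ['title A', 'abstract A'], 'context': ' context A'},
    {'texts': ['title B', 'abstract B'], 'context': ' context B'},
]
joined = f'{query} '.join([f'{doc_sep}'.join(r['texts']) + r['context'] for r in sorted_result])
print(joined)
# title A
# abstract A context Awhat causes WHS? title B
# abstract B context B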
@@ -275,10 +294,6 @@ def run_query(query):
     #     * {query_exp}
     #     """)
 
-    # address period in highlitht avoidability. Risk factors
-    # address poor tokenization Deletions involving chromosome region 4p16.3 cause WolfHirschhorn syndrome (WHS, OMIM 194190) [Battaglia et al, 2001].
-    # address highlight html
-
     # could also try fallback if there are no good answers by score...
     limit = top_hits_limit or 100
     context_limit = context_lim or 10
@@ -346,10 +361,13 @@ def run_query(query):
     else:
         threshold = (confidence_threshold or 10) / 100
 
-    sorted_result = filter(
+    sorted_result = list(filter(
         lambda x: x['score'] > threshold,
         sorted_result
-    )
+    ))
+
+    if use_mds == 'yes':
+        gen_summary(query, sorted_result)
 
     for r in sorted_result:
         ctx = remove_html(r["context"])
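
Note on wrapping filter in list: in Python 3, filter returns a one-shot iterator, so passing the bare iterator to gen_summary would consume it and leave the rendering loop below with no results. A small sketch of the difference:

results = [{'score': 0.9}, {'score': 0.2}]
lazy = filter(lambda x: x['score'] > 0.5, results)
print(list(lazy))  # [{'score': 0.9}]
print(list(lazy))  # [] -- the iterator was consumed by the first pass
eager = list(filter(lambda x: x['score'] > 0.5, results))
print(eager)  # [{'score': 0.9}]
print(eager)  # [{'score': 0.9}] -- a list can be iterated again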