Spaces · Runtime error
Commit 5cc7b84
Parent(s): f1fd3e1

remove summarization

app.py CHANGED
```diff
@@ -78,7 +78,6 @@ def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=True
         except:
             pass
 
-
     return (
         contexts,
         docs
```
```diff
@@ -149,11 +148,9 @@ def init_models():
     reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2', device=device)
     # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
     # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
-
-    summ_mdl = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384-ms2')
-    return question_answerer, reranker, stop, device, summ_mdl, summ_tok
+    return question_answerer, reranker, stop, device
 
-qa_model, reranker, stop, device, summ_mdl, summ_tok = init_models()
+qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer
 
 
 def clean_query(query, strict=True, clean=True):
```
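For reference, the reranker kept by this hunk is a sentence-transformers CrossEncoder, which scores (query, passage) pairs jointly. A minimal sketch of that usage pattern; the example query and passages are invented, not taken from app.py:

```python
# Hedged sketch of CrossEncoder reranking; assumes sentence-transformers is installed.
from sentence_transformers import CrossEncoder

reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2')

query = "What causes runtime errors?"
passages = ["Unhandled exceptions abort execution.", "Paris is in France."]

# predict() scores each (query, passage) pair; higher means more relevant.
scores = reranker.predict([(query, p) for p in passages])
ranked = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
print(ranked[0][0])  # most relevant passage first
```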
```diff
@@ -214,9 +211,6 @@ st.markdown("""
 """, unsafe_allow_html=True)
 
 with st.expander("Settings (strictness, context limit, top hits)"):
-    use_mds = st.radio(
-        "Use multi-document summarization to summarize answer?",
-        ('yes', 'no'))
     support_all = st.radio(
         "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
         ('yes', 'no'))
```
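Since st.radio returns the selected option as a plain string, deleting the use_mds widget here pairs with deleting the `use_mds == 'yes'` branch in run_query (last hunk below). A minimal sketch of the pattern, with a shortened, hypothetical label:

```python
import streamlit as st

# st.radio returns the chosen option string ('yes' by default here),
# which is why downstream code compares against the literal 'yes'.
with st.expander("Settings"):
    support_all = st.radio("Use abstracts/titles as a ranking signal?",
                           ('yes', 'no'))

if support_all == 'yes':
    ...  # weight documents whose abstracts match the query terms
```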
```diff
@@ -271,77 +265,6 @@ def matched_context(start_i, end_i, contexts_string, seperator='---'):
         return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(seperator, '')
     return None
 
-def process_document(documents, tokenizer, docsep_token_id, pad_token_id, device=device):
-    input_ids_all=[]
-    for data in documents:
-        all_docs = data.split("|||||")
-        for i, doc in enumerate(all_docs):
-            doc = doc.replace("\n", " ")
-            doc = " ".join(doc.split())
-            all_docs[i] = doc
-
-        #### concat with global attention on doc-sep
-        input_ids = []
-        for doc in all_docs:
-            input_ids.extend(
-                tokenizer.encode(
-                    doc,
-                    truncation=True,
-                    max_length=4096 // len(all_docs),
-                )[1:-1]
-            )
-            input_ids.append(docsep_token_id)
-        input_ids = (
-            [tokenizer.bos_token_id]
-            + input_ids
-            + [tokenizer.eos_token_id]
-        )
-        input_ids_all.append(torch.tensor(input_ids))
-    input_ids = torch.nn.utils.rnn.pad_sequence(
-        input_ids_all, batch_first=True, padding_value=pad_token_id
-    )
-    return input_ids
-
-
-def batch_process(batch, model, tokenizer, docsep_token_id, pad_token_id, device=device):
-    input_ids=process_document(batch['document'], tokenizer, docsep_token_id, pad_token_id)
-    # get the input ids and attention masks together
-    global_attention_mask = torch.zeros_like(input_ids).to(device)
-    input_ids = input_ids.to(device)
-    # put global attention on <s> token
-
-    global_attention_mask[:, 0] = 1
-    global_attention_mask[input_ids == docsep_token_id] = 1
-    generated_ids = model.generate(
-        input_ids=input_ids,
-        global_attention_mask=global_attention_mask,
-        use_cache=True,
-        max_length=1024,
-        num_beams=5,
-    )
-    generated_str = tokenizer.batch_decode(
-        generated_ids.tolist(), skip_special_tokens=True
-    )
-    result={}
-    result['generated_summaries'] = generated_str
-    return result
-
-
-def gen_summary(query, sorted_result):
-    pad_token_id = summ_tok.pad_token_id
-    docsep_token_id = summ_tok.convert_tokens_to_ids("</s>")
-    out = batch_process({ 'document': [f'||||'.join([f'{query} '.join(r['texts']) + r['context'] for r in sorted_result])]}, summ_mdl, summ_tok, docsep_token_id, pad_token_id)
-    st.markdown(f"""
-    <div class="container-fluid">
-        <div class="row align-items-start">
-            <div class="col-md-12 col-sm-12">
-                <strong>Answer:</strong> {out['generated_summaries'][0]}
-            </div>
-        </div>
-    </div>
-    """, unsafe_allow_html=True)
-    st.markdown("<br /><br /><h5>Sources:</h5>", unsafe_allow_html=True)
-
 
 def run_query(query):
     # if use_query_exp == 'yes':
```
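The deleted functions follow the common LED multi-document summarization pattern: per-document token ids are concatenated with a doc-separator token, and global attention is placed on `<s>` and on each separator so the encoder can attend across documents. A condensed, self-contained sketch of that same pattern, assuming the checkpoint named in the deleted code; the two example documents are invented:

```python
# Hedged sketch of the LED multi-doc summarization pattern removed above.
import torch
from transformers import AutoTokenizer, LEDForConditionalGeneration

tok = AutoTokenizer.from_pretrained('allenai/led-base-16384-ms2')
mdl = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384-ms2')

docs = ["First study abstract ...", "Second study abstract ..."]
docsep = tok.convert_tokens_to_ids("</s>")

# Concatenate per-doc token ids (bos/eos stripped), separated by the doc-sep token.
ids = []
for d in docs:
    ids.extend(tok.encode(d, truncation=True, max_length=4096 // len(docs))[1:-1])
    ids.append(docsep)
input_ids = torch.tensor([[tok.bos_token_id] + ids + [tok.eos_token_id]])

# Global attention on <s> and every separator lets LED mix information across docs.
gam = torch.zeros_like(input_ids)
gam[:, 0] = 1
gam[input_ids == docsep] = 1

out = mdl.generate(input_ids=input_ids, global_attention_mask=gam,
                   max_length=256, num_beams=5)
print(tok.batch_decode(out, skip_special_tokens=True)[0])
```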
```diff
@@ -395,7 +318,7 @@ def run_query(query):
     context = '\n---'.join(contexts[:context_limit])
 
     results = []
-    model_results = qa_model(question=query, context=context, top_k=10)
+    model_results = qa_model(question=query, context=query+'---'+context, top_k=10)
     for result in model_results:
         matched = matched_context(result['start'], result['end'], context)
         support = find_source(result['answer'], orig_docs, matched)
```
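After this change the model sees the query prepended to the context. Assuming qa_model is a Hugging Face question-answering pipeline (consistent with the keyword arguments used here), top_k=10 returns a list of candidate answers with character offsets; those offsets index into the exact string passed as context, so prepending query+'---' shifts them by len(query) + 3 relative to the bare context that matched_context receives. A hedged sketch of the call pattern; the model name is an assumption, as the diff does not show which QA model app.py loads:

```python
# Hedged sketch of the question-answering pipeline call pattern used above.
from transformers import pipeline

qa_model = pipeline("question-answering",
                    model="distilbert-base-cased-distilled-squad")

context = "Runtime errors occur when code fails during execution.\n---Paris is in France."
results = qa_model(question="When do runtime errors occur?",
                   context=context, top_k=2)

for r in results:
    # 'start'/'end' are character offsets into the `context` string.
    print(r['score'], r['answer'], context[r['start']:r['end']])
```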
```diff
@@ -423,9 +346,6 @@ def run_query(query):
         sorted_result
     ))
 
-    if use_mds == 'yes':
-        gen_summary(query, sorted_result)
-
     for r in sorted_result:
         ctx = remove_html(r["context"])
         for answer in r['texts']:
```