Spaces · Runtime error
Commit · f1fd3e1 · 1 Parent(s): f5555cd
use ms2 for summarization
app.py CHANGED
@@ -1,5 +1,5 @@
 import streamlit as st
-from transformers import pipeline
+from transformers import pipeline, AutoTokenizer, LEDForConditionalGeneration
 import requests
 from bs4 import BeautifulSoup
 import nltk
@@ -149,10 +149,11 @@ def init_models():
     reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2', device=device)
     # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
     # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
-    return question_answerer, reranker, stop, device
-
+    summ_tok = AutoTokenizer.from_pretrained('allenai/led-base-16384-ms2')
+    summ_mdl = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384-ms2')
+    return question_answerer, reranker, stop, device, summ_mdl, summ_tok
 
-qa_model, reranker, stop, device = init_models()
+qa_model, reranker, stop, device, summ_mdl, summ_tok = init_models() # queryexp_model, queryexp_tokenizer
 
 
 def clean_query(query, strict=True, clean=True):
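
For reference, a minimal standalone sketch of how the allenai/led-base-16384-ms2 checkpoint wired in above can be used. The checkpoint ID and the summ_tok / summ_mdl names come from the diff; the sample text, the generation settings, and the global-attention handling are illustrative assumptions, not code from app.py.

# Standalone sketch (assumptions: transformers and torch are installed; the text
# and generation settings are made up, only the checkpoint name is from the commit).
import torch
from transformers import AutoTokenizer, LEDForConditionalGeneration

summ_tok = AutoTokenizer.from_pretrained('allenai/led-base-16384-ms2')
summ_mdl = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384-ms2')

text = "Trial A reports a small benefit. Trial B reports no effect."
inputs = summ_tok(text, return_tensors="pt", truncation=True, max_length=4096)

# LED expects a global attention mask; global attention on the first (<s>)
# token is the usual minimum for LED-style models.
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1

summary_ids = summ_mdl.generate(
    inputs["input_ids"],
    global_attention_mask=global_attention_mask,
    max_length=256,
    num_beams=5,
)
print(summ_tok.decode(summary_ids[0], skip_special_tokens=True))
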
@@ -270,15 +271,71 @@ def matched_context(start_i, end_i, contexts_string, seperator='---'):
         return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(seperator, '')
     return None
 
+def process_document(documents, tokenizer, docsep_token_id, pad_token_id, device=device):
+    input_ids_all=[]
+    for data in documents:
+        all_docs = data.split("|||||")
+        for i, doc in enumerate(all_docs):
+            doc = doc.replace("\n", " ")
+            doc = " ".join(doc.split())
+            all_docs[i] = doc
+
+        #### concat with global attention on doc-sep
+        input_ids = []
+        for doc in all_docs:
+            input_ids.extend(
+                tokenizer.encode(
+                    doc,
+                    truncation=True,
+                    max_length=4096 // len(all_docs),
+                )[1:-1]
+            )
+            input_ids.append(docsep_token_id)
+        input_ids = (
+            [tokenizer.bos_token_id]
+            + input_ids
+            + [tokenizer.eos_token_id]
+        )
+        input_ids_all.append(torch.tensor(input_ids))
+    input_ids = torch.nn.utils.rnn.pad_sequence(
+        input_ids_all, batch_first=True, padding_value=pad_token_id
+    )
+    return input_ids
+
+
+def batch_process(batch, model, tokenizer, docsep_token_id, pad_token_id, device=device):
+    input_ids=process_document(batch['document'], tokenizer, docsep_token_id, pad_token_id)
+    # get the input ids and attention masks together
+    global_attention_mask = torch.zeros_like(input_ids).to(device)
+    input_ids = input_ids.to(device)
+    # put global attention on <s> token
+
+    global_attention_mask[:, 0] = 1
+    global_attention_mask[input_ids == docsep_token_id] = 1
+    generated_ids = model.generate(
+        input_ids=input_ids,
+        global_attention_mask=global_attention_mask,
+        use_cache=True,
+        max_length=1024,
+        num_beams=5,
+    )
+    generated_str = tokenizer.batch_decode(
+        generated_ids.tolist(), skip_special_tokens=True
+    )
+    result={}
+    result['generated_summaries'] = generated_str
+    return result
+
 
 def gen_summary(query, sorted_result):
-
-
+    pad_token_id = summ_tok.pad_token_id
+    docsep_token_id = summ_tok.convert_tokens_to_ids("</s>")
+    out = batch_process({ 'document': [f'||||'.join([f'{query} '.join(r['texts']) + r['context'] for r in sorted_result])]}, summ_mdl, summ_tok, docsep_token_id, pad_token_id)
     st.markdown(f"""
     <div class="container-fluid">
     <div class="row align-items-start">
     <div class="col-md-12 col-sm-12">
-        <strong>Answer:</strong> {
+        <strong>Answer:</strong> {out['generated_summaries'][0]}
     </div>
     </div>
     </div>
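
To make the packing in process_document concrete: it strips the per-document <s>/</s> tokens, gives each document an equal share of the 4096-token budget, appends the </s> separator after each document, and wraps the whole sequence in <s> ... </s> before padding. A small inspection sketch follows; it assumes summ_tok and process_document from above, and the two-document string is made up.

# Hypothetical inspection of the packed layout; the expected output is noted
# in the comments but not guaranteed.
ids = process_document(
    ["first document text|||||second document text"],
    summ_tok,
    docsep_token_id=summ_tok.convert_tokens_to_ids("</s>"),
    pad_token_id=summ_tok.pad_token_id,
)
print(ids.shape)  # (1, sequence_length)
print(summ_tok.convert_ids_to_tokens(ids[0].tolist()))
# expected to start with '<s>' and contain '</s>' after each document
# ('<pad>' only appears when more than one string is batched together)
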
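
And a rough end-to-end sketch of how gen_summary drives these helpers. The toy query and sorted_result are made up; the 'texts' and 'context' keys mirror how gen_summary builds its input above. Note that process_document splits documents on the five-pipe separator "|||||" while the committed gen_summary joins with four pipes, so this sketch joins with five to keep the documents separate.

# Toy call (assumes summ_mdl / summ_tok from init_models() and the helpers above).
query = "does treatment X reduce symptom Y?"
sorted_result = [
    {'texts': ["Trial A found a small reduction."], 'context': " Full passage for trial A."},
    {'texts': ["Trial B found no change."], 'context': " Full passage for trial B."},
]

# One multi-document string, using the separator process_document splits on.
document = "|||||".join(
    f"{query} ".join(r['texts']) + r['context'] for r in sorted_result
)

out = batch_process(
    {'document': [document]},
    summ_mdl,
    summ_tok,
    docsep_token_id=summ_tok.convert_tokens_to_ids("</s>"),
    pad_token_id=summ_tok.pad_token_id,
)
print(out['generated_summaries'][0])
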