Spaces:
Runtime error
Runtime error
Commit
·
00e4b2e
1
Parent(s):
577cb80
add more settings
Browse files
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
from transformers import pipeline
|
| 3 |
import requests
|
| 4 |
from bs4 import BeautifulSoup
|
| 5 |
from nltk.corpus import stopwords
|
|
@@ -80,9 +80,11 @@ def init_models():
|
|
| 80 |
device=device
|
| 81 |
)
|
| 82 |
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
|
| 83 |
-
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
qa_model, reranker, stop, device = init_models()
|
| 86 |
|
| 87 |
|
| 88 |
def clean_query(query, strict=True, clean=True):
|
|
@@ -134,7 +136,8 @@ st.title("Scientific Question Answering with Citations")
|
|
| 134 |
|
| 135 |
st.write("""
|
| 136 |
Ask a scientific question and get an answer drawn from [scite.ai](https://scite.ai) corpus of over 1.1bn citation statements.
|
| 137 |
-
Answers are linked to source documents containing citations where users can explore further evidence from scientific literature for the answer.
|
|
|
|
| 138 |
""")
|
| 139 |
|
| 140 |
st.markdown("""
|
|
@@ -145,13 +148,35 @@ with st.expander("Settings (strictness, context limit, top hits)"):
|
|
| 145 |
strict_mode = st.radio(
|
| 146 |
"Query mode? Strict means all words must match in source snippet. Lenient means only some words must match.",
|
| 147 |
('strict', 'lenient'))
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25 if torch.cuda.is_available() else 10)
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
def run_query(query):
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')
|
| 156 |
if len(contexts) == 0 or not ''.join(contexts).strip():
|
| 157 |
return st.markdown("""
|
|
@@ -164,12 +189,15 @@ def run_query(query):
|
|
| 164 |
</div>
|
| 165 |
""", unsafe_allow_html=True)
|
| 166 |
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
-
context = '\n'.join(sorted_contexts[:context_limit])
|
| 173 |
results = []
|
| 174 |
model_results = qa_model(question=query, context=context, top_k=10)
|
| 175 |
for result in model_results:
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
from transformers import pipeline, AutoTokenizer, AutoModelWithLMHead
|
| 3 |
import requests
|
| 4 |
from bs4 import BeautifulSoup
|
| 5 |
from nltk.corpus import stopwords
|
|
|
|
| 80 |
device=device
|
| 81 |
)
|
| 82 |
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
|
| 83 |
+
queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
|
| 84 |
+
queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
|
| 85 |
+
return question_answerer, reranker, stop, device, queryexp_model, queryexp_tokenizer
|
| 86 |
|
| 87 |
+
qa_model, reranker, stop, device, queryexp_model, queryexp_tokenizer = init_models()
|
| 88 |
|
| 89 |
|
| 90 |
def clean_query(query, strict=True, clean=True):
|
|
|
|
| 136 |
|
| 137 |
st.write("""
|
| 138 |
Ask a scientific question and get an answer drawn from [scite.ai](https://scite.ai) corpus of over 1.1bn citation statements.
|
| 139 |
+
Answers are linked to source documents containing citations where users can explore further evidence from scientific literature for the answer. For example try:
|
| 140 |
+
Are tanning beds safe to use? Does size of venture capital fund correlate with returns?
|
| 141 |
""")
|
| 142 |
|
| 143 |
st.markdown("""
|
|
|
|
| 148 |
strict_mode = st.radio(
|
| 149 |
"Query mode? Strict means all words must match in source snippet. Lenient means only some words must match.",
|
| 150 |
('strict', 'lenient'))
|
| 151 |
+
use_reranking = st.radio(
|
| 152 |
+
"Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
|
| 153 |
+
('yes', 'no'))
|
| 154 |
+
use_query_exp = st.radio(
|
| 155 |
+
"(Experimental) use query expansion? Right now it just recommends queries",
|
| 156 |
+
('yes', 'no'))
|
| 157 |
+
top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 200 if torch.cuda.is_available() else 100)
|
| 158 |
context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25 if torch.cuda.is_available() else 10)
|
| 159 |
|
| 160 |
+
def paraphrase(text, max_length=128, num_return_sequences=5, num_beams=5):
    """Generate alternative phrasings of *text* with the doc2query T5 model.

    Uses the module-level ``queryexp_tokenizer`` / ``queryexp_model`` pair
    (doc2query/all-with_prefix-t5-base-v1) to beam-search paraphrases.

    Parameters
    ----------
    text : str
        Input text, already carrying the task prefix the model expects
        (callers pass e.g. ``"question2question: <query>"``).
    max_length : int
        Maximum token length of each generated sequence.
    num_return_sequences : int
        Number of paraphrases to return; must not exceed ``num_beams``.
        Defaults to 5, matching the original hard-coded behavior.
    num_beams : int
        Beam width for beam-search decoding. Defaults to 5.

    Returns
    -------
    str
        The decoded paraphrases joined by newlines (suitable for direct
        display via ``st.markdown``).
    """
    input_ids = queryexp_tokenizer.encode(
        text, return_tensors="pt", add_special_tokens=True
    )
    generated_ids = queryexp_model.generate(
        input_ids=input_ids,
        num_return_sequences=num_return_sequences,
        num_beams=num_beams,
        max_length=max_length,
    )
    decoded = [
        queryexp_tokenizer.decode(
            g, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        for g in generated_ids
    ]
    return '\n'.join(decoded)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
def run_query(query):
|
| 171 |
+
if use_query_exp == 'yes':
|
| 172 |
+
query_exp = paraphrase(f"question2question: {query}")
|
| 173 |
+
st.markdown(f"""
|
| 174 |
+
If you are not getting good results try one of:
|
| 175 |
|
| 176 |
+
{query_exp}
|
| 177 |
+
""")
|
| 178 |
+
limit = top_hits_limit or 100
|
| 179 |
+
context_limit = context_lim or 10
|
| 180 |
contexts, orig_docs = search(query, limit=limit, strict=strict_mode == 'strict')
|
| 181 |
if len(contexts) == 0 or not ''.join(contexts).strip():
|
| 182 |
return st.markdown("""
|
|
|
|
| 189 |
</div>
|
| 190 |
""", unsafe_allow_html=True)
|
| 191 |
|
| 192 |
+
if use_reranking == 'yes':
|
| 193 |
+
sentence_pairs = [[query, context] for context in contexts]
|
| 194 |
+
scores = reranker.predict(sentence_pairs, batch_size=limit, show_progress_bar=False)
|
| 195 |
+
hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
|
| 196 |
+
sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
|
| 197 |
+
context = '\n'.join(sorted_contexts[:context_limit])
|
| 198 |
+
else:
|
| 199 |
+
context = '\n'.join(contexts[:context_limit])
|
| 200 |
|
|
|
|
| 201 |
results = []
|
| 202 |
model_results = qa_model(question=query, context=context, top_k=10)
|
| 203 |
for result in model_results:
|