Spaces:
Sleeping
Sleeping
Commit
·
bb78cda
1
Parent(s):
b188ae7
Add inference and explanation run time
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
|
|
| 4 |
pipeline)
|
| 5 |
import shap
|
| 6 |
from PIL import Image
|
|
|
|
| 7 |
|
| 8 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
| 9 |
output_width = 800
|
|
@@ -33,16 +34,26 @@ tokenizer, model = load_model(model_name)
|
|
| 33 |
pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
|
| 34 |
explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits)
|
| 35 |
|
| 36 |
-
col1, col2 = st.columns(
|
| 37 |
text = col1.text_area("Enter text input", value = "Classify me.")
|
| 38 |
|
|
|
|
| 39 |
result = pred(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
top_pred = result[0][0]['label']
|
| 41 |
col2.write('')
|
| 42 |
for label in result[0]:
|
| 43 |
col2.write(f'**{label["label"]}**: {label["score"]: .2f}')
|
| 44 |
|
| 45 |
shap_values = explainer([text])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
force_plot = shap.plots.text(shap_values, display=False)
|
| 48 |
bar_plot = shap.plots.bar(shap_values[0, :, top_pred], order=shap.Explanation.argsort.flip, show=False)
|
|
@@ -58,4 +69,4 @@ st.markdown(f'<center><p class="big-font">Shap Bar Plot for <i>{top_pred}</i> Pr
|
|
| 58 |
st.pyplot(bar_plot, clear_figure=True)
|
| 59 |
|
| 60 |
st.markdown('<center><p class="big-font">Shap Interactive Force Plot</p></center>', unsafe_allow_html=True)
|
| 61 |
-
components.html(force_plot, height=output_height, width=output_width, scrolling=True)
|
|
|
|
| 4 |
pipeline)
|
| 5 |
import shap
|
| 6 |
from PIL import Image
|
| 7 |
+
import time
|
| 8 |
|
| 9 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
| 10 |
output_width = 800
|
|
|
|
| 34 |
pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
|
| 35 |
explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits)
|
| 36 |
|
| 37 |
+
col1, col2, col3 = st.columns(3)
|
| 38 |
text = col1.text_area("Enter text input", value = "Classify me.")
|
| 39 |
|
| 40 |
+
start_time = time.time()
|
| 41 |
result = pred(text)
|
| 42 |
+
inference_time = time.time() - start_time
|
| 43 |
+
|
| 44 |
+
col3.write('')
|
| 45 |
+
col3.write(f'**Inference Time:** {inference_time: .4f}')
|
| 46 |
+
|
| 47 |
top_pred = result[0][0]['label']
|
| 48 |
col2.write('')
|
| 49 |
for label in result[0]:
|
| 50 |
col2.write(f'**{label["label"]}**: {label["score"]: .2f}')
|
| 51 |
|
| 52 |
shap_values = explainer([text])
|
| 53 |
+
explanation_time = shap_values.compute_time
|
| 54 |
+
|
| 55 |
+
col3.write('')
|
| 56 |
+
col3.write(f'**Explanation Time:** {explanation_time: .4f}')
|
| 57 |
|
| 58 |
force_plot = shap.plots.text(shap_values, display=False)
|
| 59 |
bar_plot = shap.plots.bar(shap_values[0, :, top_pred], order=shap.Explanation.argsort.flip, show=False)
|
|
|
|
| 69 |
st.pyplot(bar_plot, clear_figure=True)
|
| 70 |
|
| 71 |
st.markdown('<center><p class="big-font">Shap Interactive Force Plot</p></center>', unsafe_allow_html=True)
|
| 72 |
+
components.html(force_plot, height=output_height, width=output_width, scrolling=True)
|