Commit: works

Files changed:
- .gitattributes (+1, -0)
- app.py (+45, -6)
- imgs/akinator_ready.png (new file)
- requirements.txt (+3, -2)
.gitattributes
CHANGED

@@ -31,3 +31,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
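The added pattern is what `git lfs track "*.jpeg"` writes: JPEG files are stored as Git LFS pointers instead of regular git objects. A minimal Python sketch of the same edit, purely illustrative and not part of this commit (it assumes it is run from the repository root):

# Illustrative sketch: append the LFS tracking rule for *.jpeg to .gitattributes
# if it is not already present (equivalent to `git lfs track "*.jpeg"`).
from pathlib import Path

rule = "*.jpeg filter=lfs diff=lfs merge=lfs -text"
gitattributes = Path(".gitattributes")  # assumes the current directory is the repo root

lines = gitattributes.read_text().splitlines() if gitattributes.exists() else []
if rule not in lines:
    lines.append(rule)
    gitattributes.write_text("\n".join(lines) + "\n")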
app.py
CHANGED

@@ -1,13 +1,52 @@
 import streamlit as st
+import torch
+from torch.nn import functional as F
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import json
+import streamlit.components.v1 as components
 
-st.markdown("### Hello, world!")
-st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
 
-
+if __name__ == '__main__':
+    st.markdown("### Arxiv paper classifier (No guarantees provided)")
 
-
+    col1, col2 = st.columns([1, 1])
+    col1.image('imgs/akinator_ready.png', width=200)
+    btn = col2.button('Classify!')
 
-
+    model = AutoModelForSequenceClassification.from_pretrained('checkpoint-3000')
+    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
 
-
+    with open('checkpoint-3000/config.json', 'r') as f:
+        id2label = json.load(f)['id2label']
 
+    id2label = {int(key): value for key, value in id2label.items()}
+    title = st.text_area(label='Input title...', placeholder='Input title...', label_visibility='hidden', height=3)
+    abstract = st.text_area(label='Input title...', placeholder='Input abstract...', label_visibility='hidden', height=10)
+    text = '\n'.join([title, abstract])
+
+    if btn and len(text) == 1:
+        st.error('Title and abstract are empty!')
+
+    if btn and len(text) > 1:
+        tokenized = tokenizer(text)
+
+        with torch.no_grad():
+            out = model(torch.tensor(tokenized['input_ids']).unsqueeze(dim=0))
+            _, ids = torch.sort(-out['logits'])
+            probs = F.softmax(out['logits'][0, ids], dim=1)
+            ids, probs = ids[0], probs[0]
+
+        ptotal = 0
+        result = []
+        for i, prob in enumerate(probs):
+            ptotal += prob
+            result.append(f'{id2label[ids[i].item()]} (prob = {prob.item()})')
+        output = '<br>'.join(result)
+
+        components.html(f'<div>'
+                        f'<div style="height:120px;width:680px;'
+                        f'border:1px solid #ccc;border-color: red;'
+                        f'font:16px/26px Georgia, Garamond, Serif;'
+                        f'overflow:scroll;'
+                        f'color:white;">'
+                        f'{output}</div>')
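The classification path above can also be exercised outside Streamlit. The sketch below is illustrative only and not part of the commit: it assumes the same local checkpoint-3000 directory and DistilBERT tokenizer as app.py, uses placeholder title/abstract text, and reads label names from model.config.id2label instead of parsing config.json by hand.

# Illustrative sketch of the inference steps app.py performs (not part of the commit).
import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("checkpoint-3000")  # assumed local fine-tuned checkpoint
model.eval()

title = "Example paper title"        # placeholder input
abstract = "Example abstract text."  # placeholder input
text = "\n".join([title, abstract])

# Truncate so inputs longer than DistilBERT's 512-token limit do not error out.
inputs = tokenizer(text, return_tensors="pt", truncation=True)

with torch.no_grad():
    logits = model(**inputs).logits        # shape: (1, num_labels)

probs = F.softmax(logits[0], dim=-1)       # probabilities over all labels
probs, ids = torch.sort(probs, descending=True)

# model.config.id2label carries the same mapping app.py loads from config.json.
for prob, idx in zip(probs, ids):
    print(f"{model.config.id2label[idx.item()]}: {prob.item():.3f}")

Sorting the logits and then applying softmax, as the diff does, yields the same ranking and probabilities as the softmax-then-sort order used here; the main behavioural difference is the explicit truncation, since the bare tokenizer(text) call in app.py leaves long abstracts untruncated.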
imgs/akinator_ready.png
ADDED
requirements.txt
CHANGED

@@ -1,3 +1,4 @@
-transformers
-torch
+transformers==4.15.0
+torch==1.12.1
+
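A quick, purely illustrative way to confirm that a running environment actually matches the pins above (not part of the commit):

# Illustrative check that the installed packages match the pinned requirements.
import torch
import transformers

assert transformers.__version__ == "4.15.0", transformers.__version__
assert torch.__version__.startswith("1.12.1"), torch.__version__  # local builds may append e.g. "+cu113"
print("Pinned versions in place:", transformers.__version__, torch.__version__)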