Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,14 +1,17 @@
|
|
| 1 |
from transformers import pipeline
|
| 2 |
import gradio as gr
|
|
|
|
|
|
|
| 3 |
|
| 4 |
models = {
|
| 5 |
-
'devngho/ko_edu_classifier_v2_nlpai-lab_KoE5': pipeline("text-classification", model="devngho/ko_edu_classifier_v2_nlpai-lab_KoE5"),
|
| 6 |
-
'devngho/ko_edu_classifier_v2_lemon-mint_LaBSE-EnKo-Nano-Preview-v0.3': pipeline("text-classification", model="devngho/ko_edu_classifier_v2_lemon-mint_LaBSE-EnKo-Nano-Preview-v0.3"),
|
| 7 |
-
'devngho/ko_edu_classifier_v2_LaBSE': pipeline("text-classification", model="devngho/ko_edu_classifier_v2_LaBSE")
|
| 8 |
}
|
| 9 |
|
| 10 |
import gradio as gr
|
| 11 |
|
|
|
|
| 12 |
def evaluate_model(input_text):
|
| 13 |
return [model(input_text)[0]['score'] * 6 if model_name != 'devngho/ko_edu_classifier_v2_nlpai-lab_KoE5' else model('passage: ' + input_text)[0]['score'] * 6 for model_name, model in models.items()]
|
| 14 |
|
|
|
|
| 1 |
from transformers import pipeline
|
| 2 |
import gradio as gr
|
| 3 |
+
import spaces
|
| 4 |
+
import torch
|
| 5 |
|
| 6 |
models = {
|
| 7 |
+
'devngho/ko_edu_classifier_v2_nlpai-lab_KoE5': pipeline("text-classification", model="devngho/ko_edu_classifier_v2_nlpai-lab_KoE5", device='cuda', torch_dtype=torch.bfloat16),
|
| 8 |
+
'devngho/ko_edu_classifier_v2_lemon-mint_LaBSE-EnKo-Nano-Preview-v0.3': pipeline("text-classification", model="devngho/ko_edu_classifier_v2_lemon-mint_LaBSE-EnKo-Nano-Preview-v0.3", device='cuda', torch_dtype=torch.bfloat16),
|
| 9 |
+
'devngho/ko_edu_classifier_v2_LaBSE': pipeline("text-classification", model="devngho/ko_edu_classifier_v2_LaBSE", device='cuda', torch_dtype=torch.bfloat16)
|
| 10 |
}
|
| 11 |
|
| 12 |
import gradio as gr
|
| 13 |
|
| 14 |
+
@spaces.GPU
|
| 15 |
def evaluate_model(input_text):
|
| 16 |
return [model(input_text)[0]['score'] * 6 if model_name != 'devngho/ko_edu_classifier_v2_nlpai-lab_KoE5' else model('passage: ' + input_text)[0]['score'] * 6 for model_name, model in models.items()]
|
| 17 |
|