Update app.py
Browse files
app.py
CHANGED
|
@@ -7,8 +7,6 @@ import rdflib
|
|
| 7 |
from rdflib.plugins.sparql.parser import parseQuery
|
| 8 |
from huggingface_hub import InferenceClient
|
| 9 |
import re
|
| 10 |
-
import torch
|
| 11 |
-
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
|
| 12 |
# ---------------------------------------------------------------------------
|
| 13 |
# CONFIGURAZIONE LOGGING
|
| 14 |
# ---------------------------------------------------------------------------
|
|
@@ -18,22 +16,6 @@ logging.basicConfig(
|
|
| 18 |
handlers=[logging.FileHandler("app.log"), logging.StreamHandler()]
|
| 19 |
)
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
# Determina il device (GPU se disponibile, altrimenti CPU)
|
| 24 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 25 |
-
logger.info(f"Device per il classificatore: {device}")
|
| 26 |
-
|
| 27 |
-
# Carica il modello e il tokenizer del classificatore fine-tuned
|
| 28 |
-
try:
|
| 29 |
-
logger.info("Caricamento del modello di classificazione fine-tuned da 'finetuned-bert-model'.")
|
| 30 |
-
classifier_model = DistilBertForSequenceClassification.from_pretrained("finetuned-bert-model")
|
| 31 |
-
classifier_tokenizer = DistilBertTokenizer.from_pretrained("finetuned-bert-model")
|
| 32 |
-
classifier_model.to(device)
|
| 33 |
-
logger.info("Modello di classificazione caricato correttamente.")
|
| 34 |
-
except Exception as e:
|
| 35 |
-
logger.error(f"Errore nel caricamento del modello di classificazione: {e}")
|
| 36 |
-
classifier_model = None
|
| 37 |
explanation_dict = {}
|
| 38 |
# ---------------------------------------------------------------------------
|
| 39 |
# COSTANTI / CHIAVI / MODELLI
|
|
@@ -402,32 +384,6 @@ def assistant_endpoint(req: AssistantRequest):
|
|
| 402 |
max_tokens = req.max_tokens
|
| 403 |
temperature = req.temperature
|
| 404 |
logger.debug(f"Parametri utente: message='{user_message}', max_tokens={max_tokens}, temperature={temperature}")
|
| 405 |
-
# -------------------------------
|
| 406 |
-
# CLASSIFICAZIONE DEL TESTO RICEVUTO
|
| 407 |
-
# -------------------------------
|
| 408 |
-
if classifier_model is not None:
|
| 409 |
-
try:
|
| 410 |
-
# Prepara l'input per il modello di classificazione
|
| 411 |
-
inputs = classifier_tokenizer(user_message, return_tensors="pt", truncation=True, padding=True)
|
| 412 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 413 |
-
|
| 414 |
-
# Disattiva il calcolo del gradiente per velocizzare l'inferenza
|
| 415 |
-
with torch.no_grad():
|
| 416 |
-
outputs = classifier_model(**inputs)
|
| 417 |
-
logits = outputs.logits
|
| 418 |
-
pred = torch.argmax(logits, dim=1).item()
|
| 419 |
-
|
| 420 |
-
# Mappa l'etichetta numerica a una stringa (modifica secondo la tua logica)
|
| 421 |
-
label_mapping = {0: "NON PERTINENTE", 1: "PERTINENTE"}
|
| 422 |
-
classification_result = label_mapping.get(pred, f"Etichetta {pred}")
|
| 423 |
-
logger.info(f"[Classificazione] La domanda classificata come: {classification_result}")
|
| 424 |
-
explanation_dict['classification'] = f"Risultato classificazione: {classification_result}"
|
| 425 |
-
except Exception as e:
|
| 426 |
-
logger.error(f"Errore durante la classificazione della domanda: {e}")
|
| 427 |
-
explanation_dict['classification'] = f"Errore classificazione: {e}"
|
| 428 |
-
else:
|
| 429 |
-
logger.warning("Modello di classificazione non disponibile.")
|
| 430 |
-
explanation_dict['classification'] = "Modello di classificazione non disponibile."
|
| 431 |
# -----------------------------------------------------------------------
|
| 432 |
# STEP 1: Generazione della query SPARQL
|
| 433 |
# -----------------------------------------------------------------------
|
|
|
|
| 7 |
from rdflib.plugins.sparql.parser import parseQuery
|
| 8 |
from huggingface_hub import InferenceClient
|
| 9 |
import re
|
|
|
|
|
|
|
| 10 |
# ---------------------------------------------------------------------------
|
| 11 |
# CONFIGURAZIONE LOGGING
|
| 12 |
# ---------------------------------------------------------------------------
|
|
|
|
| 16 |
handlers=[logging.FileHandler("app.log"), logging.StreamHandler()]
|
| 17 |
)
|
| 18 |
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
explanation_dict = {}
|
| 20 |
# ---------------------------------------------------------------------------
|
| 21 |
# COSTANTI / CHIAVI / MODELLI
|
|
|
|
| 384 |
max_tokens = req.max_tokens
|
| 385 |
temperature = req.temperature
|
| 386 |
logger.debug(f"Parametri utente: message='{user_message}', max_tokens={max_tokens}, temperature={temperature}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
# -----------------------------------------------------------------------
|
| 388 |
# STEP 1: Generazione della query SPARQL
|
| 389 |
# -----------------------------------------------------------------------
|