Spaces:
Paused
Paused
File size: 597 Bytes
22d76f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
from transformers import ( # pylint: disable=import-error
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModelForCausalLM,
pipeline
)
import logging
class IllnessClassifier:
    """Classify free text with a Hugging Face text-classification pipeline.

    Wraps ``transformers.pipeline("text-classification", ...)`` and exposes a
    single ``forward`` call that returns the top predicted label together with
    its score.
    """

    # Hub checkpoint used when no model is given; identical to the value that
    # used to be hard-coded, so existing ``IllnessClassifier()`` calls behave
    # the same.
    DEFAULT_MODEL = "dsuram/distilbert-mentalhealth-classifier"

    # Named logger rather than the module-level ``logging.info`` helper: that
    # helper implicitly runs ``logging.basicConfig()`` on the root logger when
    # it has no handlers, a global side effect a library class should avoid.
    _logger = logging.getLogger(__name__)

    def __init__(self, model: str = DEFAULT_MODEL):
        """Build the classification pipeline.

        Args:
            model: Model identifier (or local path) passed to
                ``transformers.pipeline``. Defaults to ``DEFAULT_MODEL``.
        """
        # The pipeline is loaded eagerly, so construction can be slow and may
        # download the checkpoint if it is not already cached.
        self.classifier = pipeline("text-classification", model=model)

    def forward(self, text: str):
        """Classify ``text`` and return the top prediction.

        Args:
            text: The input text to classify.

        Returns:
            A ``(disorder, confidence)`` tuple holding the ``'label'`` and
            ``'score'`` fields of the first result the pipeline returns. The
            set of possible labels is defined by the loaded model checkpoint.
        """
        # truncation=True: without it, inputs longer than the model's maximum
        # sequence length make the pipeline raise instead of classifying.
        output = self.classifier(text, truncation=True)[0]
        disorder = output["label"]
        confidence = output["score"]
        # Lazy %-style args; renders the same text as the former f-string.
        self._logger.info("Disorder: %s, Confidence: %s", disorder, confidence)
        return disorder, confidence