File size: 597 Bytes
22d76f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from transformers import ( # pylint: disable=import-error
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    pipeline
)

import logging

class IllnessClassifier:
    """Thin wrapper around a Hugging Face ``text-classification`` pipeline.

    Runs a single piece of text through the pipeline and returns the label
    and score of the first result it reports.
    """

    # Hub model used when the caller does not choose one. This is the value
    # that used to be hard-coded, so existing ``IllnessClassifier()`` callers
    # get exactly the same model as before.
    DEFAULT_MODEL = "dsuram/distilbert-mentalhealth-classifier"

    # Named logger rather than the root-logger helper ``logging.info``: the
    # helper implicitly calls ``logging.basicConfig()`` when the root logger
    # has no handlers, which a library class should not do on the caller's
    # behalf. Records still propagate to the root logger as before.
    _logger = logging.getLogger(__name__)

    def __init__(self, model: str = DEFAULT_MODEL):
        """Build the underlying text-classification pipeline.

        Args:
            model: Model identifier (or local path) forwarded to
                ``transformers.pipeline``. Defaults to ``DEFAULT_MODEL``.
        """
        self.classifier = pipeline("text-classification", model=model)

    def forward(self, text: str):
        """Classify ``text`` and return ``(label, score)``.

        Only the first entry of the pipeline's output list is used.
        NOTE(review): with the pipeline's default settings this is presumably
        the top-scoring label — confirm if ``top_k`` is ever configured.

        Args:
            text: The input text to classify.

        Returns:
            A ``(label, score)`` tuple taken from the ``'label'`` and
            ``'score'`` keys of the first pipeline result.
        """
        output = self.classifier(text)[0]
        disorder = output["label"]
        confidence = output["score"]
        # Lazy %-style args: formatting only happens if INFO is enabled.
        self._logger.info("Disorder: %s, Confidence: %s", disorder, confidence)
        return disorder, confidence