Spaces:
Runtime error
Runtime error
| import torch | |
| import gradio as gr | |
| from transformers import ( | |
| AutoModelForSequenceClassification, | |
| AutoTokenizer, | |
| ) | |
| from typing import Dict | |
| import os | |
| import pandas as pd | |
| from huggingface_hub import login | |
# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login() with token=None falls back to an interactive prompt, which
# cannot succeed in a non-interactive process such as a hosted Space.
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Moral foundations scored by the app; each one has its own fine-tuned
# classifier on the Hub named "joshnguyen/mformer-<foundation>".
FOUNDATIONS = ["authority", "care", "fairness", "loyalty", "sanctity"]

# A single tokenizer (taken from the "authority" checkpoint) is reused for all
# five models. NOTE(review): this assumes every mformer checkpoint shares the
# same vocabulary -- confirm against the model cards.
# `token=` replaces the deprecated `use_auth_token=`; when HF_TOKEN is None the
# library falls back to any locally stored credentials / anonymous access.
tokenizer = AutoTokenizer.from_pretrained(
    "joshnguyen/mformer-authority",
    token=HF_TOKEN,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load every classifier once at import time, put it in evaluation mode
# (e.g. disables dropout) and move it to DEVICE.
# Maps foundation name -> loaded model.
MODELS = {}
for foundation in FOUNDATIONS:
    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=f"joshnguyen/mformer-{foundation}",
        token=HF_TOKEN,
    )
    model.eval()
    MODELS[foundation] = model.to(DEVICE)
def classify_text(text: str) -> pd.DataFrame:
    """Score *text* against every moral foundation.

    The text is tokenized once (truncated to the tokenizer's maximum length)
    and fed through each per-foundation classifier in ``MODELS``. For each
    model, the softmax probability of class index 1 is taken as that
    foundation's score.

    Args:
        text: The input passage to classify.

    Returns:
        A DataFrame with one row per entry of ``FOUNDATIONS`` (in that order)
        and two columns: ``"foundation"`` (capitalized name) and ``"score"``
        (a softmax probability in [0, 1]). This is the shape the ``gr.BarPlot``
        output component plots.
    """
    # Tokenize once; the same single-item batch is reused by every model.
    inputs = tokenizer(
        [text],
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)

    rows = []
    # Pure inference: skip building autograd graphs for the forward passes.
    with torch.no_grad():
        for foundation in FOUNDATIONS:
            logits = MODELS[foundation](**inputs).logits
            # NOTE(review): class index 1 is treated as the "foundation
            # present" label -- confirm against each model's id2label config.
            probs = torch.softmax(logits, dim=1)[:, 1]
            # .item() yields a plain Python float for the single batch entry.
            rows.append([foundation.capitalize(), probs[0].item()])

    return pd.DataFrame(rows, columns=["foundation", "score"])
# Multi-line text box where the user enters the passage to score.
text_input = gr.Textbox(
    label="Input text",
    placeholder="Enter some text...",
    lines=12,
    show_label=True,
    container=False,
    scale=10,
)

# Horizontal bar chart of the DataFrame returned by classify_text: one bar per
# foundation, with the value axis pinned to [0, 1] because scores are
# probabilities. Axis titles are a single space, presumably to blank them out.
# NOTE(review): `vertical` and `width` are not accepted by gr.BarPlot in every
# Gradio release -- confirm against the pinned gradio version.
score_plot = gr.BarPlot(
    x="foundation",
    y="score",
    title="Moral foundations scores",
    x_title=" ",
    y_title=" ",
    y_lim=[0, 1],
    vertical=False,
    width=500,
    height=200,
)

demo = gr.Interface(fn=classify_text, inputs=[text_input], outputs=[score_plot])

# Allow at most 20 pending requests in the queue, then start the web server.
demo.queue(max_size=20).launch()