Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| from transformers import AutoConfig | |
| from models import CustomClassifier, CustomClassificationConfig | |
# Hub repository that both the config and the classifier weights are read from.
MODEL_ID = "yhamidullah/custom-classifier-demo"

# Loaded separately so `predict` can check input length against
# `config.input_dim`.
config = CustomClassificationConfig.from_pretrained(MODEL_ID)

# Load the weights and switch the module to evaluation mode; `Module.eval()`
# returns the module itself, so the call can be chained onto the load.
model = CustomClassifier.from_pretrained(MODEL_ID).eval()
def predict(input_csv: str) -> str:
    """Classify one comma-separated feature vector with the loaded model.

    Args:
        input_csv: Raw textbox contents, e.g. ``"0.5, -1.2, 3"``. Every
            comma-separated field must be parseable by ``float()``
            (surrounding whitespace is tolerated).

    Returns:
        ``"Predicted class: <index>"`` holding the argmax of the model's
        logits, or an ``"Error: ..."`` message when the text is empty, a
        field is not a number, or the number of values differs from
        ``config.input_dim``.
    """
    # An empty or whitespace-only textbox would otherwise reach float("")
    # and raise inside the UI; report it in-band like the other input errors.
    if not input_csv or not input_csv.strip():
        return "Error: input is empty"

    vec = []
    for field in input_csv.split(","):
        try:
            vec.append(float(field))
        except ValueError:
            # Covers non-numeric text and empty fields such as a trailing comma.
            return f"Error: could not parse {field.strip()!r} as a float"

    if len(vec) != config.input_dim:
        return f"Error: Need {config.input_dim} floats"

    # Wrap in an outer list to form a batch of one: shape (1, len(vec)).
    x = torch.tensor([vec])
    with torch.no_grad():
        logits = model(input_ids=x)["logits"]
    pred = logits.argmax(dim=-1).item()
    return f"Predicted class: {pred}"
# Text-in / text-out UI around `predict`. The description states how many
# values the loaded config expects, so users see it before hitting an error.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Input Vector (comma-separated)"),
    outputs="text",
    title="Custom Classifier Demo",
    description=f"Enter {config.input_dim} comma-separated floats.",
)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. to reuse `predict`) does not launch a blocking server.
# NOTE(review): assumes the hosting environment runs this file as a script
# (`python app.py`) — confirm for the deployment target.
if __name__ == "__main__":
    demo.launch()