# Hugging Face Space status at time of capture: Runtime error
import inspect
import os

import evaluate
import numpy as np
import streamlit as st
import torch
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)
# Load the Bio_ClinicalBERT tokenizer. Streamlit re-executes this whole script
# on every widget interaction, so st.cache_resource keeps a single instance per
# server process instead of reloading it on each click.
# No bare AutoModel is loaded here: `model` is (re)assigned to the
# sequence-classification model further down before anything reads it, so a
# base-model download at this point would be wasted time and memory.
@st.cache_resource
def _load_tokenizer():
    """Return the Bio_ClinicalBERT tokenizer (cached across Streamlit reruns)."""
    return AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")


tokenizer = _load_tokenizer()
# Mapping from classifier output index to human-readable label.
id2label = {0: "SURGERY", 1: "NON-SURGERY"}
# Inverse mapping, derived from id2label so the two can never drift apart.
label2id = {label_name: label_index for label_index, label_name in id2label.items()}
# Accuracy metric from the `evaluate` library; consumed by compute_metrics.
accuracy = evaluate.load("accuracy")


def preprocess_function(examples):
    """Tokenize ``examples`` with the module-level ClinicalBERT tokenizer.

    Truncation and padding are both enabled, so over-long inputs are cut
    and a batch is padded to a common length.
    """
    tokenizer_options = {"truncation": True, "padding": True}
    return tokenizer(examples, **tokenizer_options)
# Load Bio_ClinicalBERT with a 2-way sequence-classification head whose
# outputs are labelled via id2label / label2id. Cached with st.cache_resource
# so Streamlit reruns reuse the same weights instead of reloading them on
# every interaction.
# NOTE(review): Bio_ClinicalBERT appears to be a base (not task-fine-tuned)
# checkpoint, so the classification head is presumably freshly initialised,
# and nothing visible in this script fine-tunes it. Predictions are not
# meaningful until a fine-tuned checkpoint is loaded here. Confirm via the
# "newly initialized" warning transformers prints at load time.
@st.cache_resource
def _load_classifier():
    """Return Bio_ClinicalBERT with a 2-label classification head (cached)."""
    return AutoModelForSequenceClassification.from_pretrained(
        "emilyalsentzer/Bio_ClinicalBERT",
        num_labels=2,
        id2label=id2label,
        label2id=label2id,
    )


model = _load_classifier()
def compute_metrics(eval_pred):
    """Compute accuracy for a Trainer evaluation pass.

    ``eval_pred`` unpacks into ``(logits, label_ids)``. The predicted class
    for each row is the arg-max of the logits along axis 1.
    """
    logits, label_ids = eval_pred
    predicted_classes = np.argmax(logits, axis=1)
    return accuracy.compute(references=label_ids, predictions=predicted_classes)
# Collator that pads each batch of tokenized examples to a common length.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# transformers renamed `evaluation_strategy` to `eval_strategy` (4.41) and
# later removed the old name, so pass whichever this installed version defines.
_eval_strategy_arg = (
    "eval_strategy"
    if hasattr(TrainingArguments, "eval_strategy")
    else "evaluation_strategy"
)

# Fine-tuning hyper-parameters.
training_args = TrainingArguments(
    output_dir="my_awesome_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    save_strategy="epoch",
    load_best_model_at_end=True,
    # With push_to_hub=True the Trainer creates the Hub repository when it is
    # constructed, which fails outright when no Hub credentials are available
    # (e.g. a Space without an HF_TOKEN secret). Only enable pushing when a
    # token is configured.
    push_to_hub=bool(os.environ.get("HF_TOKEN")),
    **{_eval_strategy_arg: "epoch"},
)
# Trainer wiring for fine-tuning / evaluation.
# NOTE(review): no train_dataset or eval_dataset is passed, so this Trainer
# cannot train or evaluate as written. Supply datasets before calling
# trainer.train() / trainer.evaluate().
# transformers 4.46 renamed the `tokenizer` argument to `processing_class`
# (the old name is deprecated), so pass whichever this version accepts.
_tokenizer_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{_tokenizer_kwarg: tokenizer},
)
# Streamlit UI. Streamlit re-executes the script on every interaction, so this
# section runs again each time the button is pressed.
st.title("Clinical Text Classification")
text = st.text_area("Enter clinical text:", "")
if st.button("Classify"):
    if not text.strip():
        st.warning("Please enter some clinical text to classify.")
    else:
        # Trainer.predict() expects a dataset of examples, not the BatchEncoding
        # of a single string, so run one forward pass on the model directly.
        # Inputs are moved to the model's device because the Trainer may have
        # placed the model on a GPU.
        inputs = tokenizer(text, truncation=True, return_tensors="pt").to(model.device)
        model.eval()  # disable dropout so repeated clicks give the same label
        with torch.no_grad():
            logits = model(**inputs).logits
        # logits has one row for the single input; take its arg-max class.
        prediction = int(logits.argmax(dim=-1)[0])
        st.write("Predicted Label:", id2label[prediction])