Spaces:
Runtime error
Runtime error
File size: 2,092 Bytes
2349074 84fe4f3 97cdd73 84fe4f3 97cdd73 84fe4f3 684f30c 84fe4f3 97cdd73 84fe4f3 3893344 84fe4f3 9128ec6 684f30c 84fe4f3 97cdd73 84fe4f3 97cdd73 84fe4f3 098f3a5 84fe4f3 684f30c 84fe4f3 098f3a5 001896c 84fe4f3 001896c 684f30c 84fe4f3 684f30c 84fe4f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import streamlit as st
import numpy as np
import evaluate
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
# Load the tokenizer for the Bio_ClinicalBERT checkpoint.
# The bare `AutoModel` load that used to follow was removed: `model` is
# unconditionally reassigned further down by
# AutoModelForSequenceClassification.from_pretrained, so the extra load only
# cost startup time and memory.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

# Label mappings between class indices and human-readable names.
id2label = {0: "SURGERY", 1: "NON-SURGERY"}
label2id = {"SURGERY": 0, "NON-SURGERY": 1}

# Accuracy metric used by compute_metrics.
accuracy = evaluate.load("accuracy")
# Tokenization helper
def preprocess_function(examples):
    """Tokenize `examples` with the module-level tokenizer.

    Truncation and padding are both enabled. Whatever the tokenizer call
    returns is handed back unchanged.
    """
    encoded = tokenizer(examples, padding=True, truncation=True)
    return encoded
# Build the two-label sequence classifier on top of the Bio_ClinicalBERT
# checkpoint, wiring in the label mappings defined above.
classifier_kwargs = {
    "num_labels": 2,
    "id2label": id2label,
    "label2id": label2id,
}
model = AutoModelForSequenceClassification.from_pretrained(
    "emilyalsentzer/Bio_ClinicalBERT",
    **classifier_kwargs,
)
# Metric callback handed to the Trainer
def compute_metrics(eval_pred):
    """Score predictions with the module-level `accuracy` metric.

    `eval_pred` is unpacked into (scores, reference labels). The predicted
    class for each row is the index of the largest score along axis 1.
    Returns whatever `accuracy.compute` returns.
    """
    scores, references = eval_pred
    predicted_classes = np.argmax(scores, axis=1)
    return accuracy.compute(
        predictions=predicted_classes,
        references=references,
    )
# Collator that pads each batch of tokenized features with the tokenizer.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Training arguments.
# This script never passes a train or eval dataset to the Trainer, which is
# only used for prediction, so the following options were removed:
#   * push_to_hub=True: makes Trainer construction try to create or clone
#     a Hub repository, which fails without credentials.
#   * evaluation_strategy / save_strategy / load_best_model_at_end: they
#     require an eval dataset, and the `evaluation_strategy` keyword is not
#     accepted by every transformers release.
# The optimisation hyperparameters are kept so the object stays usable if
# training is wired in later.
training_args = TrainingArguments(
    output_dir="my_awesome_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
)

# Trainer wrapping the classifier; used below for batched prediction.
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
# Streamlit UI: read free text, classify it on demand, show the label.
st.title("Clinical Text Classification")
text = st.text_area("Enter clinical text:", "")
if st.button("Classify"):
    if not text.strip():
        # Nothing to classify; avoid running the model on an empty string.
        st.warning("Please enter some clinical text to classify.")
    else:
        # Trainer.predict iterates over a dataset of per-example feature
        # dicts, so the single tokenized example is converted to a plain
        # dict and wrapped in a one-element list. The data collator then
        # pads it and converts it to tensors.
        features = dict(preprocess_function(text))
        result = trainer.predict([features])
        # Cast the NumPy integer to a plain int before the label lookup.
        prediction = int(np.argmax(result.predictions, axis=1)[0])
        st.write("Predicted Label:", id2label[prediction])
|