noequal commited on
Commit
684f30c
·
1 Parent(s): 3893344

Update app to use internal data

Browse files
Files changed (1) hide show
  1. app.py +32 -9
app.py CHANGED
@@ -1,19 +1,27 @@
1
  import streamlit as st
2
  import torch
3
  from torch.utils.data import Dataset, random_split
4
- from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
5
 
6
- # Prompt user to enter clinical text data and corresponding labels
7
- train_texts = st.text_input("Enter your clinical text data (separated by commas):")
8
- train_labels = st.text_input("Enter your corresponding labels (separated by commas):")
 
 
 
 
9
 
10
- # Convert comma-separated values into lists
11
- train_texts = train_texts.split(",")
12
- train_labels = train_labels.split(",")
 
 
 
13
 
14
  # Load pre-trained model and tokenizer
15
- model = AutoModelForCausalLM.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
16
- tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
 
17
 
18
  # Create PyTorch Dataset object
19
  class ClinicalDataset(Dataset):
@@ -31,6 +39,11 @@ class ClinicalDataset(Dataset):
31
  encoding = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
32
  return {"input_ids": encoding["input_ids"].squeeze(), "attention_mask": encoding["attention_mask"].squeeze(), "labels": torch.tensor(label)}
33
 
 
 
 
 
 
34
  dataset = ClinicalDataset(texts=train_texts, labels=train_labels, tokenizer=tokenizer)
35
 
36
  # Split dataset into training and validation sets
@@ -48,6 +61,7 @@ training_args = TrainingArguments(
48
  weight_decay=0.01, # strength of weight decay
49
  logging_dir='./logs', # directory for storing logs
50
  logging_steps=10,)
 
51
  trainer = Trainer(
52
  model=model,
53
  args=training_args,
@@ -56,4 +70,13 @@ trainer = Trainer(
56
  data_collator=lambda data: {'input_ids': torch.stack([f['input_ids'] for f in data]),
57
  'attention_mask': torch.stack([f['attention_mask'] for f in data]),
58
  'labels': torch.stack([f['labels'] for f in data])}, )
 
 
 
59
  trainer.train()
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import torch
3
  from torch.utils.data import Dataset, random_split
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, default_data_collator
5
 
6
+ # Generate sample clinical text and labels
7
+ sample_data = [
8
+ ("Had successful surgery today. Feeling relieved.", "surgery"),
9
+ ("Started new medication for pain management.", "non-surgery"),
10
+ ("Scheduled for surgery next week. Nervous but hopeful.", "surgery"),
11
+ ("Attended a seminar on non-surgical treatments.", "non-surgery"),
12
+ ]
13
 
14
+ train_texts, train_labels = zip(*sample_data)
15
+
16
+ # Logging and Outputs
17
+ st.write("Sample data:")
18
+ for text, label in zip(train_texts, train_labels):
19
+ st.write(f"Text: {text}\nLabel: {label}\n")
20
 
21
  # Load pre-trained model and tokenizer
22
+ model_name = "distilbert-base-uncased" # You can use any suitable classification model
23
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
24
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
25
 
26
  # Create PyTorch Dataset object
27
  class ClinicalDataset(Dataset):
 
39
  encoding = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
40
  return {"input_ids": encoding["input_ids"].squeeze(), "attention_mask": encoding["attention_mask"].squeeze(), "labels": torch.tensor(label)}
41
 
42
+
43
+
44
+ # Data Collator
45
+ data_collator = default_data_collator
46
+
47
  dataset = ClinicalDataset(texts=train_texts, labels=train_labels, tokenizer=tokenizer)
48
 
49
  # Split dataset into training and validation sets
 
61
  weight_decay=0.01, # strength of weight decay
62
  logging_dir='./logs', # directory for storing logs
63
  logging_steps=10,)
64
+
65
  trainer = Trainer(
66
  model=model,
67
  args=training_args,
 
70
  data_collator=lambda data: {'input_ids': torch.stack([f['input_ids'] for f in data]),
71
  'attention_mask': torch.stack([f['attention_mask'] for f in data]),
72
  'labels': torch.stack([f['labels'] for f in data])}, )
73
+
74
+
75
+ st.write("Training started...")
76
  trainer.train()
77
+ st.write("Training completed.")
78
+
79
+ # Logging Training Output
80
+ st.write("Training logs:")
81
+ with open('./logs/train.log', 'r') as log_file:
82
+ st.code(log_file.read())