noequal committed on
Commit
9128ec6
·
1 Parent(s): af8d075

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import streamlit as st
2
  import torch
3
  from torch.utils.data import Dataset, random_split
4
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, default_data_collator
5
 
6
  # Generate sample clinical text and labels
7
  sample_data = [
@@ -58,7 +58,10 @@ class ClinicalDataset(Dataset):
58
 
59
 
60
  # Data Collator
61
- data_collator = default_data_collator
 
 
 
62
 
63
  seq_length = 128
64
  dataset = ClinicalDataset(texts=train_texts, labels=train_labels, tokenizer=tokenizer, max_seq_length=seq_length)
@@ -84,11 +87,7 @@ trainer = Trainer(
84
  args=training_args,
85
  train_dataset=train_dataset,
86
  eval_dataset=val_dataset,
87
- data_collator=lambda data: {'input_ids': torch.stack([f['input_ids'] for f in data]),
88
- 'attention_mask': torch.stack([f['attention_mask'] for f in data]),
89
- 'labels': torch.stack([f['labels'] for f in data])},
90
- pad_to_max_length=True
91
-
92
  )
93
 
94
 
 
1
  import streamlit as st
2
  import torch
3
  from torch.utils.data import Dataset, random_split
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
5
 
6
  # Generate sample clinical text and labels
7
  sample_data = [
 
58
 
59
 
60
  # Data Collator
61
+ data_collator = DataCollatorForLanguageModeling(
62
+ tokenizer=tokenizer,
63
+ mlm_probability=0.15
64
+ )
65
 
66
  seq_length = 128
67
  dataset = ClinicalDataset(texts=train_texts, labels=train_labels, tokenizer=tokenizer, max_seq_length=seq_length)
 
87
  args=training_args,
88
  train_dataset=train_dataset,
89
  eval_dataset=val_dataset,
90
+ data_collator=data_collator,
 
 
 
 
91
  )
92
 
93